diff --git a/.evergreen/config.yml b/.evergreen/config.yml index bc2cf0bb4..ef27397b3 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -58,14 +58,12 @@ functions: export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration" export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin" - export UPLOAD_BUCKET="${project}" cat < expansion.yml CURRENT_VERSION: "$CURRENT_VERSION" DRIVERS_TOOLS: "$DRIVERS_TOOLS" MONGO_ORCHESTRATION_HOME: "$MONGO_ORCHESTRATION_HOME" MONGODB_BINARIES: "$MONGODB_BINARIES" - UPLOAD_BUCKET: "$UPLOAD_BUCKET" PROJECT_DIRECTORY: "$PROJECT_DIRECTORY" PREPARE_SHELL: | set -o errexit @@ -73,7 +71,6 @@ functions: export DRIVERS_TOOLS="$DRIVERS_TOOLS" export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME" export MONGODB_BINARIES="$MONGODB_BINARIES" - export UPLOAD_BUCKET="$UPLOAD_BUCKET" export PROJECT_DIRECTORY="$PROJECT_DIRECTORY" export TMPDIR="$MONGO_ORCHESTRATION_HOME/db" @@ -103,30 +100,35 @@ functions: echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config "upload coverage" : + - command: ec2.assume_role + params: + role_arn: ${assume_role_arn} - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: src/.coverage optional: true # Upload the coverage report for all tasks in a single build to the same directory. 
- remote_file: ${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name} - bucket: mciuploads + remote_file: coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name} + bucket: ${bucket_name} permissions: public-read content_type: text/html display_name: "Raw Coverage Report" "download and merge coverage" : + - command: ec2.assume_role + params: + role_arn: ${assume_role_arn} - command: shell.exec params: silent: true working_dir: "src" + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] script: | - export AWS_ACCESS_KEY_ID=${aws_key} - export AWS_SECRET_ACCESS_KEY=${aws_secret} - # Download all the task coverage files. - aws s3 cp --recursive s3://mciuploads/${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/coverage/ coverage/ + aws s3 cp --recursive s3://${bucket_name}/coverage/${revision}/${version_id}/coverage/ coverage/ - command: shell.exec params: working_dir: "src" @@ -138,24 +140,27 @@ functions: params: silent: true working_dir: "src" + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] script: | - export AWS_ACCESS_KEY_ID=${aws_key} - export AWS_SECRET_ACCESS_KEY=${aws_secret} - aws s3 cp htmlcov/ s3://mciuploads/${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1 + aws s3 cp htmlcov/ s3://${bucket_name}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1 # Attach the index.html with s3.put so it shows up in the Evergreen UI. 
- command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: src/htmlcov/index.html - remote_file: ${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/htmlcov/index.html - bucket: mciuploads + remote_file: coverage/${revision}/${version_id}/htmlcov/index.html + bucket: ${bucket_name} permissions: public-read content_type: text/html display_name: "Coverage Report HTML" "upload mo artifacts": + - command: ec2.assume_role + params: + role_arn: ${assume_role_arn} - command: shell.exec params: script: | @@ -174,37 +179,43 @@ functions: - "./**.mdmp" # Windows: minidumps - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: mongo-coredumps.tgz - remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz - bucket: mciuploads + remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/gzip} display_name: Core Dumps - Execution optional: true - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: mongodb-logs.tar.gz - remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz - bucket: mciuploads + remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/x-gzip} display_name: 
"mongodb-logs.tar.gz" - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: drivers-tools/.evergreen/orchestration/server.log - remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log - bucket: mciuploads + remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|text/plain} display_name: "orchestration.log" "upload working dir": + - command: ec2.assume_role + params: + role_arn: ${assume_role_arn} - command: archive.targz_pack params: target: "working-dir.tar.gz" @@ -213,11 +224,12 @@ functions: - "./**" - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: working-dir.tar.gz - remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz - bucket: mciuploads + remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/x-gzip} display_name: "working-dir.tar.gz" @@ -232,11 +244,12 @@ functions: - "*.lock" - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: drivers-dir.tar.gz - remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz - bucket: mciuploads + remote_file: 
${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/x-gzip} display_name: "drivers-dir.tar.gz" @@ -785,6 +798,9 @@ functions: VERSION=${VERSION} ENSURE_UNIVERSAL2=${ENSURE_UNIVERSAL2} .evergreen/release.sh "upload release": + - command: ec2.assume_role + params: + role_arn: ${assume_role_arn} - command: archive.targz_pack params: target: "release-files.tgz" @@ -793,25 +809,27 @@ functions: - "*" - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: release-files.tgz - remote_file: ${UPLOAD_BUCKET}/release/${revision}/${task_id}-${execution}-release-files.tar.gz - bucket: mciuploads + remote_file: release/${revision}/${task_id}-${execution}-release-files.tar.gz + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/gzip} display_name: Release files "download and merge releases": + - command: ec2.assume_role + params: + role_arn: ${assume_role_arn} - command: shell.exec params: silent: true + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] script: | - export AWS_ACCESS_KEY_ID=${aws_key} - export AWS_SECRET_ACCESS_KEY=${aws_secret} - # Download all the task coverage files. 
- aws s3 cp --recursive s3://mciuploads/${UPLOAD_BUCKET}/release/${revision}/ release/ + aws s3 cp --recursive s3://${bucket_name}/release/${revision}/ release/ - command: shell.exec params: shell: "bash" @@ -845,11 +863,12 @@ functions: - "*" - command: s3.put params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} + aws_key: ${AWS_ACCESS_KEY_ID} + aws_secret: ${AWS_SECRET_ACCESS_KEY} + aws_session_token: ${AWS_SESSION_TOKEN} local_file: release-files-all.tgz - remote_file: ${UPLOAD_BUCKET}/release-all/${revision}/${task_id}-${execution}-release-files-all.tar.gz - bucket: mciuploads + remote_file: release-all/${revision}/${task_id}-${execution}-release-files-all.tar.gz + bucket: ${bucket_name} permissions: public-read content_type: ${content_type|application/gzip} display_name: Release files all @@ -2108,7 +2127,7 @@ tasks: script: | ${PREPARE_SHELL} export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 - export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian10/master/latest/libmongocrypt.tar.gz + export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/${bucket_name}/libmongocrypt/debian10/master/latest/libmongocrypt.tar.gz SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/tox.sh -m test-eg - name: testazurekms-task diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index d47e3a950..1f54717a1 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -258,6 +258,10 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS + if [ -z "$TEST_ARGS" ]; then # TODO: remove this in PYTHON-4528 + python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/synchronous/ $TEST_ARGS + fi + python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS else python green_framework_test.py $GREEN_FRAMEWORK 
-v $TEST_ARGS fi diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index abdd98b72..370b8759e 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -26,9 +26,6 @@ jobs: # required for all workflows security-events: write - # required to fetch internal or private CodeQL packs - packages: read - strategy: fail-fast: false matrix: diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 8ac1d00a6..7ec55dd3b 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -10,6 +10,10 @@ on: workflow_dispatch: pull_request: workflow_call: + inputs: + ref: + required: true + type: string concurrency: group: dist-${{ github.ref }} @@ -44,6 +48,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + ref: ${{ inputs.ref }} - uses: actions/setup-python@v5 with: @@ -99,6 +104,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + ref: ${{ inputs.ref }} - uses: actions/setup-python@v5 with: diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index 5643ee1e3..5feb0d1ab 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -19,7 +19,7 @@ env: PRODUCT_NAME: PyMongo # Changes per branch SILK_ASSET_GROUP: mongodb-python-driver - EVERGREEN_PROJECT: mongodb-python-driver + EVERGREEN_PROJECT: mongo-python-driver defaults: run: @@ -32,6 +32,8 @@ jobs: permissions: id-token: write contents: write + outputs: + version: ${{ steps.pre-publish.outputs.version }} steps: - uses: mongodb-labs/drivers-github-tools/secure-checkout@v2 with: @@ -44,6 +46,7 @@ jobs: aws_secret_id: ${{ secrets.AWS_SECRET_ID }} artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }} - uses: mongodb-labs/drivers-github-tools/python/pre-publish@v2 + id: pre-publish with: version: ${{ inputs.version }} dry_run: ${{ inputs.dry_run }} @@ -51,12 +54,16 @@ jobs: build-dist: needs: [pre-publish] uses: ./.github/workflows/dist.yml + with: + ref: ${{ 
needs.pre-publish.outputs.version }} static-scan: needs: [pre-publish] uses: ./.github/workflows/codeql.yml + permissions: + security-events: write with: - ref: ${{ github.ref }} + ref: ${{ needs.pre-publish.outputs.version }} publish: needs: [build-dist, static-scan] diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index b93c93c02..cbac42f54 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -71,6 +71,9 @@ jobs: - name: Run tests run: | tox -m test + - name: Run async tests + run: | + tox -m test-async doctest: runs-on: ubuntu-latest @@ -203,3 +206,5 @@ jobs: which python pip install -e ".[test]" pytest -v + pytest -v test/synchronous/ + pytest -v test/asynchronous/ diff --git a/README.md b/README.md index 3d13f1aa9..ed434b02b 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ a native Python driver for MongoDB. The `gridfs` package is a [gridfs](https://github.com/mongodb/specifications/blob/master/source/gridfs/gridfs-spec.rst/) implementation on top of `pymongo`. -PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, and 7.0. +PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, 7.0, and 8.0. ## Support / Feedback diff --git a/doc/changelog.rst b/doc/changelog.rst index da52dab47..420c6c6de 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -6,7 +6,11 @@ Changes in Version 4.9.0 PyMongo 4.9 brings a number of improvements including: +- Added support for MongoDB 8.0. - A new asynchronous API with full asyncio support. +- Add support for :attr:`~pymongo.encryption.Algorithm.RANGE` and deprecate + :attr:`~pymongo.encryption.Algorithm.RANGEPREVIEW`. +- pymongocrypt>=1.10 is now required for :ref:`In-Use Encryption` support. Issues Resolved ............... @@ -26,12 +30,20 @@ PyMongo 4.8 brings a number of improvements including: - The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time. 
- The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8). +- Secure Software Development Life Cycle automation for release process. + GitHub Releases now include a Software Bill of Materials, and signature + files corresponding to the distribution files released on PyPI. +- Fixed a bug in change streams where both ``startAtOperationTime`` and ``resumeToken`` + could be added to a retry attempt, which caused the retry to fail. +- Fall back to the stdlib ``ssl`` module when ``pyopenssl`` import fails with AttributeError. +- Improved performance of MongoClient operations, especially when many operations are being run concurrently. Unavoidable breaking changes ............................ -- Since we are now using ``hatch`` as our build backend, we no longer have a ``setup.py`` file - and require installation using ``pip``. +- Since we are now using ``hatch`` as our build backend, we no longer have a usable ``setup.py`` file + and require installation using ``pip``. Attempts to invoke the ``setup.py`` file will raise an exception. + Additionally, ``pip`` >= 21.3 is now required for editable installs. Issues Resolved ............... 
diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 08174fd9d..9546429a3 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -42,13 +42,12 @@ from gridfs.grid_file_shared import ( _clear_entity_type_registry, ) from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot -from pymongo.asynchronous.client_session import ClientSession +from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection -from pymongo.asynchronous.common import validate_string from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.helpers import _check_write_command_response, anext -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode +from pymongo.asynchronous.helpers import anext +from pymongo.common import validate_string from pymongo.errors import ( BulkWriteError, ConfigurationError, @@ -57,11 +56,13 @@ from pymongo.errors import ( InvalidOperation, OperationFailure, ) +from pymongo.helpers_shared import _check_write_command_response +from pymongo.read_preferences import ReadPreference, _ServerMode _IS_SYNC = False -def _disallow_transactions(session: Optional[ClientSession]) -> None: +def _disallow_transactions(session: Optional[AsyncClientSession]) -> None: if session and session.in_transaction: raise InvalidOperation("GridFS does not support multi-document transactions") @@ -155,7 +156,7 @@ class AsyncGridFS: await grid_file.write(data) return await grid_file._id - async def get(self, file_id: Any, session: Optional[ClientSession] = None) -> AsyncGridOut: + async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut: """Get a file from GridFS by ``"_id"``. 
Returns an instance of :class:`~gridfs.grid_file.GridOut`, @@ -163,7 +164,7 @@ class AsyncGridFS: :param file_id: ``"_id"`` of the file to get :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -178,7 +179,7 @@ class AsyncGridFS: self, filename: Optional[str] = None, version: Optional[int] = -1, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> AsyncGridOut: """Get a file from GridFS by ``"filename"`` or metadata fields. @@ -205,7 +206,7 @@ class AsyncGridFS: :param version: version of the file to get (defaults to -1, the most recent version uploaded) :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` :param kwargs: find files by custom metadata. .. versionchanged:: 3.6 @@ -234,7 +235,10 @@ class AsyncGridFS: raise NoFile("no version %d for filename %r" % (version, filename)) from None async def get_last_version( - self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any + self, + filename: Optional[str] = None, + session: Optional[AsyncClientSession] = None, + **kwargs: Any, ) -> AsyncGridOut: """Get the most recent version of a file in GridFS by ``"filename"`` or metadata fields. @@ -244,7 +248,7 @@ class AsyncGridFS: :param filename: ``"filename"`` of the file to get, or `None` :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` :param kwargs: find files by custom metadata. .. versionchanged:: 3.6 @@ -253,7 +257,7 @@ class AsyncGridFS: return await self.get_version(filename=filename, session=session, **kwargs) # TODO add optional safe mode for chunk removal? 
- async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None: + async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None: """Delete a file from GridFS by ``"_id"``. Deletes all data belonging to the file with ``"_id"``: @@ -269,7 +273,7 @@ class AsyncGridFS: :param file_id: ``"_id"`` of the file to delete :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -281,12 +285,12 @@ class AsyncGridFS: await self._files.delete_one({"_id": file_id}, session=session) await self._chunks.delete_many({"files_id": file_id}, session=session) - async def list(self, session: Optional[ClientSession] = None) -> list[str]: + async def list(self, session: Optional[AsyncClientSession] = None) -> list[str]: """List the names of all files stored in this instance of :class:`GridFS`. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -306,7 +310,7 @@ class AsyncGridFS: async def find_one( self, filter: Optional[Any] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, *args: Any, **kwargs: Any, ) -> Optional[AsyncGridOut]: @@ -327,7 +331,7 @@ class AsyncGridFS: :param args: any additional positional arguments are the same as the arguments to :meth:`find`. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` :param kwargs: any additional keyword arguments are the same as the arguments to :meth:`find`. @@ -370,7 +374,7 @@ class AsyncGridFS: :meth:`~pymongo.collection.Collection.find` in :class:`~pymongo.collection.Collection`. 
- If a :class:`~pymongo.client_session.ClientSession` is passed to + If a :class:`~pymongo.client_session.AsyncClientSession` is passed to :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances are associated with that session. @@ -406,7 +410,7 @@ class AsyncGridFS: async def exists( self, document_or_id: Optional[Any] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> bool: """Check if a file exists in this instance of :class:`GridFS`. @@ -438,7 +442,7 @@ class AsyncGridFS: :param document_or_id: query document, or _id of the document to check for :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` :param kwargs: keyword arguments are used as a query document, if they're present. @@ -525,7 +529,7 @@ class AsyncGridFSBucket: filename: str, chunk_size_bytes: Optional[int] = None, metadata: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> AsyncGridIn: """Opens a Stream that the application can write the contents of the file to. @@ -556,7 +560,7 @@ class AsyncGridFSBucket: files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -580,7 +584,7 @@ class AsyncGridFSBucket: filename: str, chunk_size_bytes: Optional[int] = None, metadata: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> AsyncGridIn: """Opens a Stream that the application can write the contents of the file to. @@ -615,7 +619,7 @@ class AsyncGridFSBucket: files collection document. 
If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -641,7 +645,7 @@ class AsyncGridFSBucket: source: Any, chunk_size_bytes: Optional[int] = None, metadata: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> ObjectId: """Uploads a user file to a GridFS bucket. @@ -672,7 +676,7 @@ class AsyncGridFSBucket: files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -692,7 +696,7 @@ class AsyncGridFSBucket: source: Any, chunk_size_bytes: Optional[int] = None, metadata: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> None: """Uploads a user file to a GridFS bucket with a custom file id. @@ -724,7 +728,7 @@ class AsyncGridFSBucket: files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -735,7 +739,7 @@ class AsyncGridFSBucket: await gin.write(source) async def open_download_stream( - self, file_id: Any, session: Optional[ClientSession] = None + self, file_id: Any, session: Optional[AsyncClientSession] = None ) -> AsyncGridOut: """Opens a Stream from which the application can read the contents of the stored file specified by file_id. @@ -755,7 +759,7 @@ class AsyncGridFSBucket: :param file_id: The _id of the file to be downloaded. 
:param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -768,7 +772,7 @@ class AsyncGridFSBucket: @_csot.apply async def download_to_stream( - self, file_id: Any, destination: Any, session: Optional[ClientSession] = None + self, file_id: Any, destination: Any, session: Optional[AsyncClientSession] = None ) -> None: """Downloads the contents of the stored file specified by file_id and writes the contents to `destination`. @@ -790,7 +794,7 @@ class AsyncGridFSBucket: :param file_id: The _id of the file to be downloaded. :param destination: a file-like object implementing :meth:`write`. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -803,7 +807,7 @@ class AsyncGridFSBucket: destination.write(chunk) @_csot.apply - async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None: + async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None: """Given an file_id, delete this stored file's files collection document and associated chunks from a GridFS bucket. @@ -819,7 +823,7 @@ class AsyncGridFSBucket: :param file_id: The _id of the file to be deleted. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -859,7 +863,7 @@ class AsyncGridFSBucket: :meth:`~pymongo.collection.Collection.find` in :class:`~pymongo.collection.Collection`. - If a :class:`~pymongo.client_session.ClientSession` is passed to + If a :class:`~pymongo.client_session.AsyncClientSession` is passed to :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances are associated with that session. 
@@ -878,7 +882,7 @@ class AsyncGridFSBucket: return AsyncGridOutCursor(self._collection, *args, **kwargs) async def open_download_stream_by_name( - self, filename: str, revision: int = -1, session: Optional[ClientSession] = None + self, filename: str, revision: int = -1, session: Optional[AsyncClientSession] = None ) -> AsyncGridOut: """Opens a Stream from which the application can read the contents of `filename` and optional `revision`. @@ -902,7 +906,7 @@ class AsyncGridFSBucket: filename and different uploadDate) of the file to retrieve. Defaults to -1 (the most recent revision). :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` :Note: Revision numbers are defined as follows: @@ -937,7 +941,7 @@ class AsyncGridFSBucket: filename: str, destination: Any, revision: int = -1, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> None: """Write the contents of `filename` (with optional `revision`) to `destination`. @@ -961,7 +965,7 @@ class AsyncGridFSBucket: filename and different uploadDate) of the file to retrieve. Defaults to -1 (the most recent revision). :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` :Note: Revision numbers are defined as follows: @@ -985,7 +989,7 @@ class AsyncGridFSBucket: destination.write(chunk) async def rename( - self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None + self, file_id: Any, new_filename: str, session: Optional[AsyncClientSession] = None ) -> None: """Renames the stored file with the specified file_id. @@ -1002,7 +1006,7 @@ class AsyncGridFSBucket: :param file_id: The _id of the file to be renamed. :param new_filename: The new name of the file. :param session: a - :class:`~pymongo.client_session.ClientSession` + :class:`~pymongo.client_session.AsyncClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. 
@@ -1024,7 +1028,7 @@ class AsyncGridIn: def __init__( self, root_collection: AsyncCollection, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> None: """Write a file to GridFS @@ -1059,7 +1063,7 @@ class AsyncGridIn: :param root_collection: root collection to write to :param session: a - :class:`~pymongo.client_session.ClientSession` to use for all + :class:`~pymongo.client_session.AsyncClientSession` to use for all commands :param kwargs: Any: file level options (see above) @@ -1402,7 +1406,7 @@ class AsyncGridOut(io.IOBase): root_collection: AsyncCollection, file_id: Optional[int] = None, file_document: Optional[Any] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> None: """Read a file from GridFS @@ -1420,7 +1424,7 @@ class AsyncGridOut(io.IOBase): :param file_document: file document from `root_collection.files` :param session: a - :class:`~pymongo.client_session.ClientSession` to use for all + :class:`~pymongo.client_session.AsyncClientSession` to use for all commands .. 
versionchanged:: 3.8 @@ -1734,7 +1738,7 @@ class _AsyncGridOutChunkIterator: self, grid_out: AsyncGridOut, chunks: AsyncCollection, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], next_chunk: Any, ) -> None: self._id = grid_out._id @@ -1824,7 +1828,9 @@ class _AsyncGridOutChunkIterator: class AsyncGridOutIterator: - def __init__(self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: ClientSession): + def __init__( + self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession + ): self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0) def __aiter__(self) -> AsyncGridOutIterator: @@ -1851,7 +1857,7 @@ class AsyncGridOutCursor(AsyncCursor): no_cursor_timeout: bool = False, sort: Optional[Any] = None, batch_size: int = 0, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> None: """Create a new cursor, similar to the normal :class:`~pymongo.cursor.Cursor`. @@ -1894,6 +1900,6 @@ class AsyncGridOutCursor(AsyncCursor): def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn: raise NotImplementedError("Method does not exist for GridOutCursor") - def _clone_base(self, session: Optional[ClientSession]) -> AsyncGridOutCursor: + def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncGridOutCursor: """Creates an empty GridOutCursor for information to be copied into.""" return AsyncGridOutCursor(self._root_collection, session=session) diff --git a/gridfs/grid_file_shared.py b/gridfs/grid_file_shared.py index f6c37b9f3..b6f02a53d 100644 --- a/gridfs/grid_file_shared.py +++ b/gridfs/grid_file_shared.py @@ -5,7 +5,7 @@ import warnings from typing import Any, Optional from pymongo import ASCENDING -from pymongo.asynchronous.common import MAX_MESSAGE_SIZE +from pymongo.common import MAX_MESSAGE_SIZE from pymongo.errors import InvalidOperation _SEEK_SET = os.SEEK_SET diff --git a/gridfs/synchronous/grid_file.py 
b/gridfs/synchronous/grid_file.py index 0e9842992..98374cc8c 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -42,6 +42,7 @@ from gridfs.grid_file_shared import ( _grid_out_property, ) from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot +from pymongo.common import validate_string from pymongo.errors import ( BulkWriteError, ConfigurationError, @@ -50,13 +51,13 @@ from pymongo.errors import ( InvalidOperation, OperationFailure, ) +from pymongo.helpers_shared import _check_write_command_response +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.collection import Collection -from pymongo.synchronous.common import validate_string from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database -from pymongo.synchronous.helpers import _check_write_command_response, next -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode +from pymongo.synchronous.helpers import next _IS_SYNC = True @@ -234,7 +235,10 @@ class GridFS: raise NoFile("no version %d for filename %r" % (version, filename)) from None def get_last_version( - self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any + self, + filename: Optional[str] = None, + session: Optional[ClientSession] = None, + **kwargs: Any, ) -> GridOut: """Get the most recent version of a file in GridFS by ``"filename"`` or metadata fields. @@ -497,7 +501,7 @@ class GridFSBucket: .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(db, Database): - raise TypeError("database must be an instance of AsyncDatabase") + raise TypeError("database must be an instance of Database") db = _clear_entity_type_registry(db) @@ -1028,7 +1032,7 @@ class GridIn: provided by :class:`~gridfs.GridFS`. 
Raises :class:`TypeError` if `root_collection` is not an - instance of :class:`~pymongo.collection.AsyncCollection`. + instance of :class:`~pymongo.collection.Collection`. Any of the file level options specified in the `GridFS Spec `_ may be passed as @@ -1069,10 +1073,10 @@ class GridIn: .. versionchanged:: 3.0 `root_collection` must use an acknowledged - :attr:`~pymongo.collection.AsyncCollection.write_concern` + :attr:`~pymongo.collection.Collection.write_concern` """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError("root_collection must be an instance of Collection") if not root_collection.write_concern.acknowledged: raise ConfigurationError("root_collection must use acknowledged write_concern") @@ -1401,7 +1405,7 @@ class GridOut(io.IOBase): Either `file_id` or `file_document` must be specified, `file_document` will be given priority if present. Raises :class:`TypeError` if `root_collection` is not an instance of - :class:`~pymongo.collection.AsyncCollection`. + :class:`~pymongo.collection.Collection`. :param root_collection: root collection to read from :param file_id: value of ``"_id"`` for the file to read @@ -1424,7 +1428,7 @@ class GridOut(io.IOBase): from the server. Metadata is fetched when first needed. 
""" if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError("root_collection must be an instance of Collection") _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) @@ -1482,7 +1486,7 @@ class GridOut(io.IOBase): self.open() # type: ignore[unused-coroutine] elif not self._file: raise InvalidOperation( - "You must call AsyncGridOut.open() before accessing the %s property" % name + "You must call GridOut.open() before accessing the %s property" % name ) if name in self._file: return self._file[name] @@ -1677,13 +1681,13 @@ class GridOut(io.IOBase): return False def __enter__(self) -> GridOut: - """Makes it possible to use :class:`AsyncGridOut` files + """Makes it possible to use :class:`GridOut` files with the async context manager protocol. """ return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: - """Makes it possible to use :class:`AsyncGridOut` files + """Makes it possible to use :class:`GridOut` files with the async context manager protocol. 
""" self.close() diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 8992281db..7ee177bda 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -89,11 +89,9 @@ TEXT = "text" from pymongo import _csot from pymongo._version import __version__, get_version_string, version_tuple from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION from pymongo.cursor import CursorType -from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import ( +from pymongo.operations import ( DeleteMany, DeleteOne, IndexModel, @@ -102,7 +100,9 @@ from pymongo.synchronous.operations import ( UpdateMany, UpdateOne, ) -from pymongo.synchronous.read_preferences import ReadPreference +from pymongo.read_preferences import ReadPreference +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern version = __version__ diff --git a/pymongo/asynchronous/aggregation.py b/pymongo/asynchronous/aggregation.py index 9fc2dae3c..fa6cefd53 100644 --- a/pymongo/asynchronous/aggregation.py +++ b/pymongo/asynchronous/aggregation.py @@ -18,20 +18,20 @@ from __future__ import annotations from collections.abc import Callable, Mapping, MutableMapping from typing import TYPE_CHECKING, Any, Optional, Union -from pymongo.asynchronous import common -from pymongo.asynchronous.collation import validate_collation_or_none -from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref +from pymongo import common +from pymongo.collation import validate_collation_or_none from pymongo.errors import ConfigurationError +from pymongo.read_preferences import ReadPreference, _AggWritePref if TYPE_CHECKING: - from 
pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.database import AsyncDatabase - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server - from pymongo.asynchronous.typings import _DocumentType, _Pipeline + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _DocumentType, _Pipeline _IS_SYNC = False @@ -53,7 +53,7 @@ class _AggregationCommand: explicit_session: bool, let: Optional[Mapping[str, Any]] = None, user_fields: Optional[MutableMapping[str, Any]] = None, - result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None, + result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None, comment: Any = None, ) -> None: if "explain" in options: @@ -121,7 +121,7 @@ class _AggregationCommand: raise NotImplementedError def get_read_preference( - self, session: Optional[ClientSession] + self, session: Optional[AsyncClientSession] ) -> Union[_AggWritePref, _ServerMode]: if self._write_preference: return self._write_preference @@ -132,9 +132,9 @@ class _AggregationCommand: async def get_cursor( self, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> AsyncCommandCursor[_DocumentType]: # Serialize command. 
diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index 41e012022..1fb28f6c4 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -18,17 +18,13 @@ from __future__ import annotations import functools import hashlib import hmac -import os import socket -import typing from base64 import standard_b64decode, standard_b64encode -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, Callable, Coroutine, - Dict, Mapping, MutableMapping, Optional, @@ -41,17 +37,19 @@ from pymongo.asynchronous.auth_aws import _authenticate_aws from pymongo.asynchronous.auth_oidc import ( _authenticate_oidc, _get_authenticator, - _OIDCAzureCallback, - _OIDCGCPCallback, - _OIDCProperties, - _OIDCTestCallback, +) +from pymongo.auth_shared import ( + MongoCredential, + _authenticate_scram_start, + _parse_scram_response, + _xor, ) from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep if TYPE_CHECKING: - from pymongo.asynchronous.hello import Hello - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.hello import Hello HAVE_KERBEROS = True _USE_PRINCIPAL = False @@ -69,213 +67,9 @@ except ImportError: _IS_SYNC = False -MECHANISMS = frozenset( - [ - "GSSAPI", - "MONGODB-CR", - "MONGODB-OIDC", - "MONGODB-X509", - "MONGODB-AWS", - "PLAIN", - "SCRAM-SHA-1", - "SCRAM-SHA-256", - "DEFAULT", - ] -) -"""The authentication mechanisms supported by PyMongo.""" - - -class _Cache: - __slots__ = ("data",) - - _hash_val = hash("_Cache") - - def __init__(self) -> None: - self.data = None - - def __eq__(self, other: object) -> bool: - # Two instances must always compare equal. 
- if isinstance(other, _Cache): - return True - return NotImplemented - - def __ne__(self, other: object) -> bool: - if isinstance(other, _Cache): - return False - return NotImplemented - - def __hash__(self) -> int: - return self._hash_val - - -MongoCredential = namedtuple( - "MongoCredential", - ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], -) -"""A hashable namedtuple of values used for authentication.""" - - -GSSAPIProperties = namedtuple( - "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] -) -"""Mechanism properties for GSSAPI authentication.""" - - -_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) -"""Mechanism properties for MONGODB-AWS authentication.""" - - -def _build_credentials_tuple( - mech: str, - source: Optional[str], - user: str, - passwd: str, - extra: Mapping[str, Any], - database: Optional[str], -) -> MongoCredential: - """Build and return a mechanism specific credentials tuple.""" - if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: - raise ConfigurationError(f"{mech} requires a username.") - if mech == "GSSAPI": - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for GSSAPI") - properties = extra.get("authmechanismproperties", {}) - service_name = properties.get("SERVICE_NAME", "mongodb") - canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False)) - service_realm = properties.get("SERVICE_REALM") - props = GSSAPIProperties( - service_name=service_name, - canonicalize_host_name=canonicalize, - service_realm=service_realm, - ) - # Source is always $external. 
- return MongoCredential(mech, "$external", user, passwd, props, None) - elif mech == "MONGODB-X509": - if passwd is not None: - raise ConfigurationError("Passwords are not supported by MONGODB-X509") - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for MONGODB-X509") - # Source is always $external, user can be None. - return MongoCredential(mech, "$external", user, None, None, None) - elif mech == "MONGODB-AWS": - if user is not None and passwd is None: - raise ConfigurationError("username without a password is not supported by MONGODB-AWS") - if source is not None and source != "$external": - raise ConfigurationError( - "authentication source must be $external or None for MONGODB-AWS" - ) - - properties = extra.get("authmechanismproperties", {}) - aws_session_token = properties.get("AWS_SESSION_TOKEN") - aws_props = _AWSProperties(aws_session_token=aws_session_token) - # user can be None for temporary link-local EC2 credentials. 
- return MongoCredential(mech, "$external", user, passwd, aws_props, None) - elif mech == "MONGODB-OIDC": - properties = extra.get("authmechanismproperties", {}) - callback = properties.get("OIDC_CALLBACK") - human_callback = properties.get("OIDC_HUMAN_CALLBACK") - environ = properties.get("ENVIRONMENT") - token_resource = properties.get("TOKEN_RESOURCE", "") - default_allowed = [ - "*.mongodb.net", - "*.mongodb-dev.net", - "*.mongodb-qa.net", - "*.mongodbgov.net", - "localhost", - "127.0.0.1", - "::1", - ] - allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) - msg = ( - "authentication with MONGODB-OIDC requires providing either a callback or a environment" - ) - if passwd is not None: - msg = "password is not supported by MONGODB-OIDC" - raise ConfigurationError(msg) - if callback or human_callback: - if environ is not None: - raise ConfigurationError(msg) - if callback and human_callback: - msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" - raise ConfigurationError(msg) - elif environ is not None: - if environ == "test": - if user is not None: - msg = "test environment for MONGODB-OIDC does not support username" - raise ConfigurationError(msg) - callback = _OIDCTestCallback() - elif environ == "azure": - passwd = None - if not token_resource: - raise ConfigurationError( - "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" - ) - callback = _OIDCAzureCallback(token_resource) - elif environ == "gcp": - passwd = None - if not token_resource: - raise ConfigurationError( - "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" - ) - callback = _OIDCGCPCallback(token_resource) - else: - raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}") - else: - raise ConfigurationError(msg) - - oidc_props = _OIDCProperties( - callback=callback, - human_callback=human_callback, - environment=environ, - allowed_hosts=allowed_hosts, - token_resource=token_resource, - 
username=user, - ) - return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) - - elif mech == "PLAIN": - source_database = source or database or "$external" - return MongoCredential(mech, source_database, user, passwd, None, None) - else: - source_database = source or database or "admin" - if passwd is None: - raise ConfigurationError("A password is required.") - return MongoCredential(mech, source_database, user, passwd, None, _Cache()) - - -def _xor(fir: bytes, sec: bytes) -> bytes: - """XOR two byte strings together.""" - return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) - - -def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]: - """Split a scram response into key, value pairs.""" - return dict( - typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1)) - for item in response.split(b",") - ) - - -def _authenticate_scram_start( - credentials: MongoCredential, mechanism: str -) -> tuple[bytes, bytes, MutableMapping[str, Any]]: - username = credentials.username - user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") - nonce = standard_b64encode(os.urandom(32)) - first_bare = b"n=" + user + b",r=" + nonce - - cmd = { - "saslStart": 1, - "mechanism": mechanism, - "payload": Binary(b"n,," + first_bare), - "autoAuthorize": 1, - "options": {"skipEmptyExchange": True}, - } - return nonce, first_bare, cmd - async def _authenticate_scram( - credentials: MongoCredential, conn: Connection, mechanism: str + credentials: MongoCredential, conn: AsyncConnection, mechanism: str ) -> None: """Authenticate using SCRAM.""" username = credentials.username @@ -398,7 +192,7 @@ def _canonicalize_hostname(hostname: str) -> str: return name[0].lower() -async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None: +async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnection) -> None: """Authenticate using GSSAPI.""" if not HAVE_KERBEROS: raise ConfigurationError( @@ 
-509,7 +303,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) - raise OperationFailure(str(exc)) from None -async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None: +async def _authenticate_plain(credentials: MongoCredential, conn: AsyncConnection) -> None: """Authenticate using SASL PLAIN (RFC 4616)""" source = credentials.source username = credentials.username @@ -524,7 +318,7 @@ async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> await conn.command(source, cmd) -async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None: +async def _authenticate_x509(credentials: MongoCredential, conn: AsyncConnection) -> None: """Authenticate using MONGODB-X509.""" ctx = conn.auth_ctx if ctx and ctx.speculate_succeeded(): @@ -535,7 +329,7 @@ async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> await conn.command("$external", cmd) -async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None: +async def _authenticate_mongo_cr(credentials: MongoCredential, conn: AsyncConnection) -> None: """Authenticate using MONGODB-CR.""" source = credentials.source username = credentials.username @@ -550,7 +344,7 @@ async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) await conn.command(source, query) -async def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None: +async def _authenticate_default(credentials: MongoCredential, conn: AsyncConnection) -> None: if conn.max_wire_version >= 7: if conn.negotiated_mechs: mechs = conn.negotiated_mechs @@ -652,7 +446,7 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Any] = { async def authenticate( - credentials: MongoCredential, conn: Connection, reauthenticate: bool = False + credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool = False ) -> None: """Authenticate connection.""" mechanism = credentials.mechanism diff 
--git a/pymongo/asynchronous/auth_aws.py b/pymongo/asynchronous/auth_aws.py index 7cab111b3..9dcc625d1 100644 --- a/pymongo/asynchronous/auth_aws.py +++ b/pymongo/asynchronous/auth_aws.py @@ -23,13 +23,13 @@ from pymongo.errors import ConfigurationError, OperationFailure if TYPE_CHECKING: from bson.typings import _ReadableBuffer - from pymongo.asynchronous.auth import MongoCredential - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.auth_shared import MongoCredential _IS_SYNC = False -async def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: +async def _authenticate_aws(credentials: MongoCredential, conn: AsyncConnection) -> None: """Authenticate using MONGODB-AWS.""" try: import pymongo_auth_aws # type:ignore[import] diff --git a/pymongo/asynchronous/auth_oidc.py b/pymongo/asynchronous/auth_oidc.py index 022a173dc..f5801b85d 100644 --- a/pymongo/asynchronous/auth_oidc.py +++ b/pymongo/asynchronous/auth_oidc.py @@ -15,79 +15,35 @@ """MONGODB-OIDC Authentication helpers.""" from __future__ import annotations -import abc -import os import threading import time from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union -from urllib.parse import quote import bson from bson.binary import Binary -from pymongo._azure_helpers import _get_azure_response from pymongo._csot import remaining -from pymongo._gcp_helpers import _get_gcp_response +from pymongo.auth_oidc_shared import ( + CALLBACK_VERSION, + HUMAN_CALLBACK_TIMEOUT_SECONDS, + MACHINE_CALLBACK_TIMEOUT_SECONDS, + TIME_BETWEEN_CALLS_SECONDS, + OIDCCallback, + OIDCCallbackContext, + OIDCCallbackResult, + OIDCIdPInfo, + _OIDCProperties, +) from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE +from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE if TYPE_CHECKING: - from 
pymongo.asynchronous.auth import MongoCredential - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.auth_shared import MongoCredential _IS_SYNC = False -@dataclass -class OIDCIdPInfo: - issuer: str - clientId: Optional[str] = field(default=None) - requestScopes: Optional[list[str]] = field(default=None) - - -@dataclass -class OIDCCallbackContext: - timeout_seconds: float - username: str - version: int - refresh_token: Optional[str] = field(default=None) - idp_info: Optional[OIDCIdPInfo] = field(default=None) - - -@dataclass -class OIDCCallbackResult: - access_token: str - expires_in_seconds: Optional[float] = field(default=None) - refresh_token: Optional[str] = field(default=None) - - -class OIDCCallback(abc.ABC): - """A base class for defining OIDC callbacks.""" - - @abc.abstractmethod - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - """Convert the given BSON value into our own type.""" - - -@dataclass -class _OIDCProperties: - callback: Optional[OIDCCallback] = field(default=None) - human_callback: Optional[OIDCCallback] = field(default=None) - environment: Optional[str] = field(default=None) - allowed_hosts: list[str] = field(default_factory=list) - token_resource: Optional[str] = field(default=None) - username: str = "" - - -"""Mechanism properties for MONGODB-OIDC authentication.""" - -TOKEN_BUFFER_MINUTES = 5 -HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60 -CALLBACK_VERSION = 1 -MACHINE_CALLBACK_TIMEOUT_SECONDS = 60 -TIME_BETWEEN_CALLS_SECONDS = 0.1 - - def _get_authenticator( credentials: MongoCredential, address: tuple[str, int] ) -> _OIDCAuthenticator: @@ -117,48 +73,6 @@ def _get_authenticator( return credentials.cache.data -class _OIDCTestCallback(OIDCCallback): - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - token_file = os.environ.get("OIDC_TOKEN_FILE") - if not token_file: - raise RuntimeError( - 'MONGODB-OIDC with an "test" provider 
requires "OIDC_TOKEN_FILE" to be set' - ) - with open(token_file) as fid: - return OIDCCallbackResult(access_token=fid.read().strip()) - - -class _OIDCAWSCallback(OIDCCallback): - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") - if not token_file: - raise RuntimeError( - 'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set' - ) - with open(token_file) as fid: - return OIDCCallbackResult(access_token=fid.read().strip()) - - -class _OIDCAzureCallback(OIDCCallback): - def __init__(self, token_resource: str) -> None: - self.token_resource = quote(token_resource) - - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds) - return OIDCCallbackResult( - access_token=resp["access_token"], expires_in_seconds=resp["expires_in"] - ) - - -class _OIDCGCPCallback(OIDCCallback): - def __init__(self, token_resource: str) -> None: - self.token_resource = quote(token_resource) - - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - resp = _get_gcp_response(self.token_resource, context.timeout_seconds) - return OIDCCallbackResult(access_token=resp["access_token"]) - - @dataclass class _OIDCAuthenticator: username: str @@ -170,7 +84,7 @@ class _OIDCAuthenticator: lock: threading.Lock = field(default_factory=threading.Lock) last_call_time: float = field(default=0) - async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: + async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]: """Handle a reauthenticate from the server.""" # Invalidate the token for the connection. 
self._invalidate(conn) @@ -179,7 +93,7 @@ class _OIDCAuthenticator: return await self._authenticate_machine(conn) return await self._authenticate_human(conn) - async def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: + async def authenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]: """Handle an initial authenticate request.""" # First handle speculative auth. # If it succeeded, we are done. @@ -203,7 +117,7 @@ class _OIDCAuthenticator: return None return self._get_start_command({"jwt": self.access_token}) - async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]: + async def _authenticate_machine(self, conn: AsyncConnection) -> Mapping[str, Any]: # If there is a cached access token, try to authenticate with it. If # authentication fails with error code 18, invalidate the access token, # fetch a new access token, and try to authenticate again. If authentication @@ -217,7 +131,7 @@ class _OIDCAuthenticator: raise return await self._sasl_start_jwt(conn) - async def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]: + async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]: # If we have a cached access token, try a JwtStepRequest. # authentication fails with error code 18, invalidate the access token, # and try to authenticate again. 
If authentication fails for any other @@ -307,7 +221,7 @@ class _OIDCAuthenticator: return self.access_token async def _run_command( - self, conn: Connection, cmd: MutableMapping[str, Any] + self, conn: AsyncConnection, cmd: MutableMapping[str, Any] ) -> Mapping[str, Any]: try: return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] @@ -321,7 +235,7 @@ class _OIDCAuthenticator: return False return err.code == _AUTHENTICATION_FAILURE_CODE - def _invalidate(self, conn: Connection) -> None: + def _invalidate(self, conn: AsyncConnection) -> None: # Ignore the invalidation if a token gen id is given and is less than our # current token gen id. token_gen_id = conn.oidc_token_gen_id or 0 @@ -330,7 +244,7 @@ class _OIDCAuthenticator: self.access_token = None async def _sasl_continue_jwt( - self, conn: Connection, start_resp: Mapping[str, Any] + self, conn: AsyncConnection, start_resp: Mapping[str, Any] ) -> Mapping[str, Any]: self.access_token = None self.refresh_token = None @@ -342,7 +256,7 @@ class _OIDCAuthenticator: cmd = self._get_continue_command({"jwt": access_token}, start_resp) return await self._run_command(conn, cmd) - async def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]: + async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]: access_token = self._get_access_token() conn.oidc_token_gen_id = self.token_gen_id cmd = self._get_start_command({"jwt": access_token}) @@ -370,7 +284,7 @@ class _OIDCAuthenticator: async def _authenticate_oidc( - credentials: MongoCredential, conn: Connection, reauthenticate: bool + credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool ) -> Optional[Mapping[str, Any]]: """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, conn.address) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index f6b45e0fa..c200899dd 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ 
-19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -26,7 +28,6 @@ from typing import ( Any, Iterator, Mapping, - NoReturn, Optional, Type, Union, @@ -34,142 +35,52 @@ from typing import ( from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument -from pymongo import _csot -from pymongo.asynchronous import common -from pymongo.asynchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.asynchronous.common import ( +from pymongo import _csot, common +from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern +from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.bulk_shared import ( + _COMMANDS, + _DELETE_ALL, + _merge_command, + _raise_bulk_write_error, + _Run, +) +from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, validate_ok_for_update, ) -from pymongo.asynchronous.helpers import _get_wce_doc -from pymongo.asynchronous.message import ( +from pymongo.errors import ( + ConfigurationError, + InvalidOperation, + NotPrimaryError, + OperationFailure, +) +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, + _convert_exception, + _convert_write_result, _EncryptedBulkWriteContext, _randint, ) -from pymongo.asynchronous.read_preferences import ReadPreference -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure, -) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES +from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from pymongo.asynchronous.collection import AsyncCollection - from pymongo.asynchronous.pool 
import Connection - from pymongo.asynchronous.typings import _DocumentOut, _DocumentType, _Pipeline + from pymongo.asynchronous.mongo_client import AsyncMongoClient + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _IS_SYNC = False -_DELETE_ALL: int = 0 -_DELETE_ONE: int = 1 -# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err -_BAD_VALUE: int = 2 -_UNKNOWN_ERROR: int = 8 -_WRITE_CONCERN_ERROR: int = 64 - -_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") - - -class _Run: - """Represents a batch of write operations.""" - - def __init__(self, op_type: int) -> None: - """Initialize a new Run object.""" - self.op_type: int = op_type - self.index_map: list[int] = [] - self.ops: list[Any] = [] - self.idx_offset: int = 0 - - def index(self, idx: int) -> int: - """Get the original index of an operation in this run. - - :param idx: The Run index that maps to the original index. - """ - return self.index_map[idx] - - def add(self, original_index: int, operation: Any) -> None: - """Add an operation to this Run instance. - - :param original_index: The original index of this operation - within a larger bulk operation. - :param operation: The operation document. 
- """ - self.index_map.append(original_index) - self.ops.append(operation) - - -def _merge_command( - run: _Run, - full_result: MutableMapping[str, Any], - offset: int, - result: Mapping[str, Any], -) -> None: - """Merge a write command result into the full bulk result.""" - affected = result.get("n", 0) - - if run.op_type == _INSERT: - full_result["nInserted"] += affected - - elif run.op_type == _DELETE: - full_result["nRemoved"] += affected - - elif run.op_type == _UPDATE: - upserted = result.get("upserted") - if upserted: - n_upserted = len(upserted) - for doc in upserted: - doc["index"] = run.index(doc["index"] + offset) - full_result["upserted"].extend(upserted) - full_result["nUpserted"] += n_upserted - full_result["nMatched"] += affected - n_upserted - else: - full_result["nMatched"] += affected - full_result["nModified"] += result["nModified"] - - write_errors = result.get("writeErrors") - if write_errors: - for doc in write_errors: - # Leave the server response intact for APM. - replacement = doc.copy() - idx = doc["index"] + offset - replacement["index"] = run.index(idx) - # Add the failed operation to the error document. - replacement["op"] = run.ops[idx] - full_result["writeErrors"].append(replacement) - - wce = _get_wce_doc(result) - if wce: - full_result["writeConcernErrors"].append(wce) - - -def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: - """Raise a BulkWriteError from the full bulk api result.""" - # retryWrites on MMAPv1 should raise an actionable error. - if full_result["writeErrors"]: - full_result["writeErrors"].sort(key=lambda error: error["index"]) - err = full_result["writeErrors"][0] - code = err["code"] - msg = err["errmsg"] - if code == 20 and msg.startswith("Transaction numbers"): - errmsg = ( - "This MongoDB deployment does not support " - "retryable writes. Please add retryWrites=false " - "to your connection string." 
- ) - raise OperationFailure(errmsg, code, full_result) - raise BulkWriteError(full_result) - - -class _Bulk: +class _AsyncBulk: """The private guts of the bulk write API.""" def __init__( @@ -180,7 +91,7 @@ class _Bulk: comment: Optional[str] = None, let: Optional[Any] = None, ) -> None: - """Initialize a _Bulk instance.""" + """Initialize a _AsyncBulk instance.""" self.collection = collection.with_options( codec_options=collection.codec_options._replace( unicode_decode_error_handler="replace", document_class=dict @@ -204,13 +115,16 @@ class _Bulk: # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None + self.is_encrypted = False @property def bulk_ctx_class(self) -> Type[_BulkWriteContext]: encrypter = self.collection.database.client._encrypter if encrypter and not encrypter._bypass_auto_encryption: + self.is_encrypted = True return _EncryptedBulkWriteContext else: + self.is_encrypted = False return _BulkWriteContext def add_insert(self, document: _DocumentOut) -> None: @@ -315,12 +229,236 @@ class _Bulk: if run.ops: yield run + @_handle_reauth + async def write_command( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + docs: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> dict[str, Any]: + """A proxy for AsyncConnection.write_command that handles event publishing.""" + cmd[bwc.field] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, docs) + try: + reply = await 
bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return reply # type: ignore[return-value] + + async def unack_write( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + max_doc_size: int, + docs: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> Optional[Mapping[str, Any]]: + """A proxy for 
AsyncConnection.unack_write that handles event publishing.""" + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, docs) + try: + result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + 
clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return result # type: ignore[return-value] + + async def _execute_batch_unack( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> list[Mapping[str, Any]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + await bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + write_concern=WriteConcern(w=0), + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. 
+ await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type] + + return to_send + + async def _execute_batch( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + result = await bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + codec_options=bwc.codec, + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type] + await client._process_response(result, bwc.session) # type: ignore[arg-type] + + return result, to_send # type: ignore[return-value] + async def _execute_command( self, generator: Iterator[Any], write_concern: WriteConcern, - session: Optional[ClientSession], - conn: Connection, + session: Optional[AsyncClientSession], + conn: AsyncConnection, op_id: int, retryable: bool, full_result: MutableMapping[str, Any], @@ -335,8 +473,8 @@ class _Bulk: self.next_run = None run = self.current_run - # Connection.command validates the session, but we use - # Connection.write_command + # AsyncConnection.command validates the session, but we use + # AsyncConnection.write_command conn.validate_session(client, session) last_run = False @@ -387,7 +525,7 @@ class _Bulk: # Run as many ops as possible in one command. if write_concern.acknowledged: - result, to_send = await bwc.execute(cmd, ops, client) + result, to_send = await self._execute_batch(bwc, cmd, ops, client) # Retryable writeConcernErrors halt the execution of this run. 
wce = result.get("writeConcernError", {}) @@ -407,7 +545,7 @@ class _Bulk: if self.ordered and "writeErrors" in result: break else: - to_send = await bwc.execute_unack(cmd, ops, client) + to_send = await self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) @@ -422,7 +560,7 @@ class _Bulk: self, generator: Iterator[Any], write_concern: WriteConcern, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, ) -> dict[str, Any]: """Execute using write commands.""" @@ -440,7 +578,7 @@ class _Bulk: op_id = _randint() async def retryable_bulk( - session: Optional[ClientSession], conn: Connection, retryable: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool ) -> None: await self._execute_command( generator, @@ -458,7 +596,7 @@ class _Bulk: retryable_bulk, session, operation, - bulk=self, + bulk=self, # type: ignore[arg-type] operation_id=op_id, ) @@ -466,7 +604,9 @@ class _Bulk: _raise_bulk_write_error(full_result) return full_result - async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + async def execute_op_msg_no_results( + self, conn: AsyncConnection, generator: Iterator[Any] + ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -499,13 +639,13 @@ class _Bulk: conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. 
- to_send = await bwc.execute_unack(cmd, ops, client) + to_send = await self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) async def execute_command_no_results( self, - conn: Connection, + conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -541,7 +681,7 @@ class _Bulk: async def execute_no_results( self, - conn: Connection, + conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -573,7 +713,7 @@ class _Bulk: async def execute( self, write_concern: WriteConcern, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, ) -> Any: """Execute operations.""" diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index b910767c5..e298df43a 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -21,17 +21,14 @@ from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union from bson import CodecOptions, _bson_to_dict from bson.raw_bson import RawBSONDocument from bson.timestamp import Timestamp -from pymongo import _csot -from pymongo.asynchronous import common +from pymongo import _csot, common from pymongo.asynchronous.aggregation import ( _AggregationCommand, _CollectionAggregationCommand, _DatabaseAggregationCommand, ) -from pymongo.asynchronous.collation import validate_collation_or_none from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.collation import validate_collation_or_none from pymongo.errors import ( ConnectionFailure, CursorNotFound, @@ -39,6 +36,8 @@ from pymongo.errors import ( OperationFailure, PyMongoError, ) +from pymongo.operations import _Op +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline 
_IS_SYNC = False @@ -68,11 +67,11 @@ _RESUMABLE_GETMORE_ERRORS = frozenset( if TYPE_CHECKING: - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection def _resumable(exc: PyMongoError) -> bool: @@ -114,7 +113,7 @@ class ChangeStream(Generic[_DocumentType]): batch_size: Optional[int], collation: Optional[_CollationIn], start_at_operation_time: Optional[Timestamp], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], start_after: Optional[Mapping[str, Any]], comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -211,7 +210,7 @@ class ChangeStream(Generic[_DocumentType]): full_pipeline.extend(self._pipeline) return full_pipeline - def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: + def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> None: """Callback that caches the postBatchResumeToken or startAtOperationTime from a changeStream aggregate command response containing an empty batch of change documents. @@ -237,7 +236,7 @@ class ChangeStream(Generic[_DocumentType]): ) async def _run_aggregation_cmd( - self, session: Optional[ClientSession], explicit_session: bool + self, session: Optional[AsyncClientSession], explicit_session: bool ) -> AsyncCommandCursor: """Run the full aggregation pipeline for this ChangeStream and return the corresponding AsyncCommandCursor. 
diff --git a/pymongo/asynchronous/client_options.py b/pymongo/asynchronous/client_options.py deleted file mode 100644 index 834b61ceb..000000000 --- a/pymongo/asynchronous/client_options.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Tools to parse mongo client options.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast - -from bson.codec_options import _parse_codec_options -from pymongo.asynchronous import common -from pymongo.asynchronous.compression_support import CompressionSettings -from pymongo.asynchronous.monitoring import _EventListener, _EventListeners -from pymongo.asynchronous.pool import PoolOptions -from pymongo.asynchronous.read_preferences import ( - _ServerMode, - make_read_preference, - read_pref_mode_from_name, -) -from pymongo.asynchronous.server_selectors import any_server_selector -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.ssl_support import get_ssl_context -from pymongo.write_concern import WriteConcern, validate_boolean - -if TYPE_CHECKING: - from bson.codec_options import CodecOptions - from pymongo.asynchronous.auth import MongoCredential - from pymongo.asynchronous.encryption_options import AutoEncryptionOpts - from pymongo.asynchronous.topology_description import _ServerSelector - from pymongo.pyopenssl_context import SSLContext - 
-_IS_SYNC = False - - -def _parse_credentials( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> Optional[MongoCredential]: - """Parse authentication credentials.""" - mechanism = options.get("authmechanism", "DEFAULT" if username else None) - source = options.get("authsource") - if username or mechanism: - from pymongo.asynchronous.auth import _build_credentials_tuple - - return _build_credentials_tuple(mechanism, source, username, password, options, database) - return None - - -def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: - """Parse read preference options.""" - if "read_preference" in options: - return options["read_preference"] - - name = options.get("readpreference", "primary") - mode = read_pref_mode_from_name(name) - tags = options.get("readpreferencetags") - max_staleness = options.get("maxstalenessseconds", -1) - return make_read_preference(mode, tags, max_staleness) - - -def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: - """Parse write concern options.""" - concern = options.get("w") - wtimeout = options.get("wtimeoutms") - j = options.get("journal") - fsync = options.get("fsync") - return WriteConcern(concern, wtimeout, j, fsync) - - -def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: - """Parse read concern options.""" - concern = options.get("readconcernlevel") - return ReadConcern(concern) - - -def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: - """Parse ssl options.""" - use_tls = options.get("tls") - if use_tls is not None: - validate_boolean("tls", use_tls) - - certfile = options.get("tlscertificatekeyfile") - passphrase = options.get("tlscertificatekeyfilepassword") - ca_certs = options.get("tlscafile") - crlfile = options.get("tlscrlfile") - allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) - allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) - 
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) - - enabled_tls_opts = [] - for opt in ( - "tlscertificatekeyfile", - "tlscertificatekeyfilepassword", - "tlscafile", - "tlscrlfile", - ): - # Any non-null value of these options implies tls=True. - if opt in options and options[opt]: - enabled_tls_opts.append(opt) - for opt in ( - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", - ): - # A value of False for these options implies tls=True. - if opt in options and not options[opt]: - enabled_tls_opts.append(opt) - - if enabled_tls_opts: - if use_tls is None: - # Implicitly enable TLS when one of the tls* options is set. - use_tls = True - elif not use_tls: - # Error since tls is explicitly disabled but a tls option is set. - raise ConfigurationError( - "TLS has not been enabled but the " - "following tls parameters have been set: " - "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) - ) - - if use_tls: - ctx = get_ssl_context( - certfile, - passphrase, - ca_certs, - crlfile, - allow_invalid_certificates, - allow_invalid_hostnames, - disable_ocsp_endpoint_check, - ) - return ctx, allow_invalid_hostnames - return None, allow_invalid_hostnames - - -def _parse_pool_options( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> PoolOptions: - """Parse connection pool options.""" - credentials = _parse_credentials(username, password, database, options) - max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) - min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) - max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) - if max_pool_size is not None and min_pool_size > max_pool_size: - raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") - connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) - socket_timeout = options.get("sockettimeoutms") - wait_queue_timeout 
= options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) - event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) - appname = options.get("appname") - driver = options.get("driver") - server_api = options.get("server_api") - compression_settings = CompressionSettings( - options.get("compressors", []), options.get("zlibcompressionlevel", -1) - ) - ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) - load_balanced = options.get("loadbalanced") - max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) - return PoolOptions( - max_pool_size, - min_pool_size, - max_idle_time_seconds, - connect_timeout, - socket_timeout, - wait_queue_timeout, - ssl_context, - tls_allow_invalid_hostnames, - _EventListeners(event_listeners), - appname, - driver, - compression_settings, - max_connecting=max_connecting, - server_api=server_api, - load_balanced=load_balanced, - credentials=credentials, - ) - - -class ClientOptions: - """Read only configuration options for an AsyncMongoClient. - - Should not be instantiated directly by application developers. Access - a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options` - instead. - """ - - def __init__( - self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] - ): - self.__options = options - self.__codec_options = _parse_codec_options(options) - self.__direct_connection = options.get("directconnection") - self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) - # self.__server_selection_timeout is in seconds. Must use full name for - # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
- self.__server_selection_timeout = options.get( - "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT - ) - self.__pool_options = _parse_pool_options(username, password, database, options) - self.__read_preference = _parse_read_preference(options) - self.__replica_set_name = options.get("replicaset") - self.__write_concern = _parse_write_concern(options) - self.__read_concern = _parse_read_concern(options) - self.__connect = options.get("connect") - self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) - self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) - self.__retry_reads = options.get("retryreads", common.RETRY_READS) - self.__server_selector = options.get("server_selector", any_server_selector) - self.__auto_encryption_opts = options.get("auto_encryption_opts") - self.__load_balanced = options.get("loadbalanced") - self.__timeout = options.get("timeoutms") - self.__server_monitoring_mode = options.get( - "servermonitoringmode", common.SERVER_MONITORING_MODE - ) - - @property - def _options(self) -> Mapping[str, Any]: - """The original options used to create this ClientOptions.""" - return self.__options - - @property - def connect(self) -> Optional[bool]: - """Whether to begin discovering a MongoDB topology automatically.""" - return self.__connect - - @property - def codec_options(self) -> CodecOptions: - """A :class:`~bson.codec_options.CodecOptions` instance.""" - return self.__codec_options - - @property - def direct_connection(self) -> Optional[bool]: - """Whether to connect to the deployment in 'Single' topology.""" - return self.__direct_connection - - @property - def local_threshold_ms(self) -> int: - """The local threshold for this instance.""" - return self.__local_threshold_ms - - @property - def server_selection_timeout(self) -> int: - """The server selection timeout for this instance in seconds.""" - return self.__server_selection_timeout - - @property - def server_selector(self) 
-> _ServerSelector: - return self.__server_selector - - @property - def heartbeat_frequency(self) -> int: - """The monitoring frequency in seconds.""" - return self.__heartbeat_frequency - - @property - def pool_options(self) -> PoolOptions: - """A :class:`~pymongo.pool.PoolOptions` instance.""" - return self.__pool_options - - @property - def read_preference(self) -> _ServerMode: - """A read preference instance.""" - return self.__read_preference - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self.__replica_set_name - - @property - def write_concern(self) -> WriteConcern: - """A :class:`~pymongo.write_concern.WriteConcern` instance.""" - return self.__write_concern - - @property - def read_concern(self) -> ReadConcern: - """A :class:`~pymongo.read_concern.ReadConcern` instance.""" - return self.__read_concern - - @property - def timeout(self) -> Optional[float]: - """The configured timeoutMS converted to seconds, or None. - - .. versionadded:: 4.2 - """ - return self.__timeout - - @property - def retry_writes(self) -> bool: - """If this instance should retry supported write operations.""" - return self.__retry_writes - - @property - def retry_reads(self) -> bool: - """If this instance should retry supported read operations.""" - return self.__retry_reads - - @property - def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: - """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" - return self.__auto_encryption_opts - - @property - def load_balanced(self) -> Optional[bool]: - """True if the client was configured to connect to a load balancer.""" - return self.__load_balanced - - @property - def event_listeners(self) -> list[_EventListeners]: - """The event listeners registered for this client. - - See :mod:`~pymongo.monitoring` for details. - - .. 
versionadded:: 4.0 - """ - assert self.__pool_options._event_listeners is not None - return self.__pool_options._event_listeners.event_listeners() - - @property - def server_monitoring_mode(self) -> str: - """The configured serverMonitoringMode option. - - .. versionadded:: 4.5 - """ - return self.__server_monitoring_mode diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 75c6bad7c..4f2f2d97d 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -44,8 +44,8 @@ Transactions .. versionadded:: 3.7 MongoDB 4.0 adds support for transactions on replica set primaries. A -transaction is associated with a :class:`ClientSession`. To start a transaction -on a session, use :meth:`ClientSession.start_transaction` in a with-statement. +transaction is associated with an :class:`AsyncClientSession`. To start a transaction +on a session, use :meth:`AsyncClientSession.start_transaction` in an ``async with`` statement. Then, execute an operation within the transaction by passing the session to the operation: @@ -63,9 +63,9 @@ operation: ) Upon normal completion of ``async with session.start_transaction()`` block, the -transaction automatically calls :meth:`ClientSession.commit_transaction`. +transaction automatically calls :meth:`AsyncClientSession.commit_transaction`. If the block exits with an exception, the transaction automatically calls -:meth:`ClientSession.abort_transaction`. +:meth:`AsyncClientSession.abort_transaction`. In general, multi-document transactions only support read/write (CRUD) operations on existing collections. 
However, MongoDB 4.4 adds support for @@ -157,8 +157,6 @@ from bson.int64 import Int64 from bson.timestamp import Timestamp from pymongo import _csot from pymongo.asynchronous.cursor import _ConnectionManager -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode from pymongo.errors import ( ConfigurationError, ConnectionFailure, @@ -167,23 +165,25 @@ from pymongo.errors import ( PyMongoError, WTimeoutError, ) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.operations import _Op from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_type import SERVER_TYPE from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from types import TracebackType - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server - from pymongo.asynchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = False class SessionOptions: - """Options for a new :class:`ClientSession`. + """Options for a new :class:`AsyncClientSession`. :param causal_consistency: If True, read operations are causally ordered within the session. Defaults to True when the ``snapshot`` @@ -246,7 +246,7 @@ class SessionOptions: class TransactionOptions: - """Options for :meth:`ClientSession.start_transaction`. + """Options for :meth:`AsyncClientSession.start_transaction`. :param read_concern: The :class:`~pymongo.read_concern.ReadConcern` to use for this transaction. 
@@ -336,8 +336,8 @@ class TransactionOptions: def _validate_session_write_concern( - session: Optional[ClientSession], write_concern: Optional[WriteConcern] -) -> Optional[ClientSession]: + session: Optional[AsyncClientSession], write_concern: Optional[WriteConcern] +) -> Optional[AsyncClientSession]: """Validate that an explicit session is not used with an unack'ed write. Returns the session to use for the next operation. @@ -362,7 +362,7 @@ def _validate_session_write_concern( class _TransactionContext: """Internal transaction context manager for start_transaction.""" - def __init__(self, session: ClientSession): + def __init__(self, session: AsyncClientSession): self.__session = session async def __aenter__(self) -> _TransactionContext: @@ -391,7 +391,7 @@ class _TxnState: class _Transaction: - """Internal class to hold transaction information in a ClientSession.""" + """Internal class to hold transaction information in a AsyncClientSession.""" def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient): self.opts = opts @@ -410,12 +410,12 @@ class _Transaction: return self.state == _TxnState.STARTING @property - def pinned_conn(self) -> Optional[Connection]: + def pinned_conn(self) -> Optional[AsyncConnection]: if self.active() and self.conn_mgr: return self.conn_mgr.conn return None - def pin(self, server: Server, conn: Connection) -> None: + def pin(self, server: Server, conn: AsyncConnection) -> None: self.sharded = True self.pinned_address = server.description.address if server.description.server_type == SERVER_TYPE.LoadBalancer: @@ -481,16 +481,16 @@ if TYPE_CHECKING: from pymongo.asynchronous.mongo_client import AsyncMongoClient -class ClientSession: +class AsyncClientSession: """A session for ordering sequential operations. - :class:`ClientSession` instances are **not thread-safe or fork-safe**. + :class:`AsyncClientSession` instances are **not thread-safe or fork-safe**. They can only be used by one thread or process at a time. 
A single - :class:`ClientSession` cannot be used to run multiple operations + :class:`AsyncClientSession` cannot be used to run multiple operations concurrently. Should not be initialized directly by application developers - to create a - :class:`ClientSession`, call + :class:`AsyncClientSession`, call :meth:`~pymongo.mongo_client.AsyncMongoClient.start_session`. """ @@ -541,7 +541,7 @@ class ClientSession: if self._server_session is None: raise InvalidOperation("Cannot use ended session") - async def __aenter__(self) -> ClientSession: + async def __aenter__(self) -> AsyncClientSession: return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -560,14 +560,14 @@ class ClientSession: return self._options @property - async def session_id(self) -> Mapping[str, Any]: + def session_id(self) -> Mapping[str, Any]: """A BSON document, the opaque server session identifier.""" self._check_ended() self._materialize(self._client.topology_description.logical_session_timeout_minutes) return self._server_session.session_id @property - async def _transaction_id(self) -> Int64: + def _transaction_id(self) -> Int64: """The current transaction id for the underlying server session.""" self._materialize(self._client.topology_description.logical_session_timeout_minutes) return self._server_session.transaction_id @@ -598,7 +598,7 @@ class ClientSession: async def with_transaction( self, - callback: Callable[[ClientSession], _T], + callback: Callable[[AsyncClientSession], _T], read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, read_preference: Optional[_ServerMode] = None, @@ -646,25 +646,25 @@ class ClientSession: however, ``with_transaction`` will return without taking further action. - :class:`ClientSession` instances are **not thread-safe or fork-safe**. + :class:`AsyncClientSession` instances are **not thread-safe or fork-safe**. 
Consequently, the ``callback`` must not attempt to execute multiple operations concurrently. When ``callback`` raises an exception, ``with_transaction`` automatically aborts the current transaction. When ``callback`` or - :meth:`~ClientSession.commit_transaction` raises an exception that + :meth:`~AsyncClientSession.commit_transaction` raises an exception that includes the ``"TransientTransactionError"`` error label, ``with_transaction`` starts a new transaction and re-executes the ``callback``. - When :meth:`~ClientSession.commit_transaction` raises an exception with + When :meth:`~AsyncClientSession.commit_transaction` raises an exception with the ``"UnknownTransactionCommitResult"`` error label, ``with_transaction`` retries the commit until the result of the transaction is known. This method will cease retrying after 120 seconds has elapsed. This timeout is not configurable and any exception raised by the - ``callback`` or by :meth:`ClientSession.commit_transaction` after the + ``callback`` or by :meth:`AsyncClientSession.commit_transaction` after the timeout is reached will be re-raised. Applications that desire a different timeout duration should not use this method. @@ -850,7 +850,7 @@ class ClientSession: """ async def func( - _session: Optional[ClientSession], conn: Connection, _retryable: bool + _session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool ) -> dict[str, Any]: return await self._finish_transaction(conn, command_name) @@ -858,7 +858,7 @@ class ClientSession: func, self, None, retryable=True, operation=_Op.ABORT ) - async def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: + async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 opts = self._transaction.opts assert opts @@ -897,8 +897,8 @@ class ClientSession: """Update the cluster time for this session. 
:param cluster_time: The - :data:`~pymongo.client_session.ClientSession.cluster_time` from - another `ClientSession` instance. + :data:`~pymongo.client_session.AsyncClientSession.cluster_time` from + another `AsyncClientSession` instance. """ if not isinstance(cluster_time, _Mapping): raise TypeError("cluster_time must be a subclass of collections.Mapping") @@ -918,8 +918,8 @@ class ClientSession: """Update the operation time for this session. :param operation_time: The - :data:`~pymongo.client_session.ClientSession.operation_time` from - another `ClientSession` instance. + :data:`~pymongo.client_session.AsyncClientSession.operation_time` from + another `AsyncClientSession` instance. """ if not isinstance(operation_time, Timestamp): raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") @@ -966,11 +966,11 @@ class ClientSession: return None @property - def _pinned_connection(self) -> Optional[Connection]: + def _pinned_connection(self) -> Optional[AsyncConnection]: """The connection this transaction was started on.""" return self._transaction.pinned_conn - def _pin(self, server: Server, conn: Connection) -> None: + def _pin(self, server: Server, conn: AsyncConnection) -> None: """Pin this session to the given Server or to the given connection.""" self._transaction.pin(server, conn) @@ -999,7 +999,7 @@ class ClientSession: command: MutableMapping[str, Any], is_retryable: bool, read_preference: _ServerMode, - conn: Connection, + conn: AsyncConnection, ) -> None: if not conn.supports_sessions: if not self._implicit: @@ -1042,7 +1042,7 @@ class ClientSession: self._check_ended() self._server_session.inc_transaction_id() - def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None: + def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: AsyncConnection) -> None: if self.options.causal_consistency and self.operation_time is not None: cmd.setdefault("readConcern", {})["afterClusterTime"] = 
self.operation_time if self.options.snapshot: @@ -1054,7 +1054,7 @@ class ClientSession: rc["atClusterTime"] = self._snapshot_time def __copy__(self) -> NoReturn: - raise TypeError("A ClientSession cannot be copied, create a new session instead") + raise TypeError("A AsyncClientSession cannot be copied, create a new session instead") class _EmptyServerSession: diff --git a/pymongo/asynchronous/collation.py b/pymongo/asynchronous/collation.py deleted file mode 100644 index 26d5a68d7..000000000 --- a/pymongo/asynchronous/collation.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for working with `collations`_. - -.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ -""" -from __future__ import annotations - -from typing import Any, Mapping, Optional, Union - -from pymongo.asynchronous import common -from pymongo.write_concern import validate_boolean - -_IS_SYNC = False - - -class CollationStrength: - """ - An enum that defines values for `strength` on a - :class:`~pymongo.collation.Collation`. 
- """ - - PRIMARY = 1 - """Differentiate base (unadorned) characters.""" - - SECONDARY = 2 - """Differentiate character accents.""" - - TERTIARY = 3 - """Differentiate character case.""" - - QUATERNARY = 4 - """Differentiate words with and without punctuation.""" - - IDENTICAL = 5 - """Differentiate unicode code point (characters are exactly identical).""" - - -class CollationAlternate: - """ - An enum that defines values for `alternate` on a - :class:`~pymongo.collation.Collation`. - """ - - NON_IGNORABLE = "non-ignorable" - """Spaces and punctuation are treated as base characters.""" - - SHIFTED = "shifted" - """Spaces and punctuation are *not* considered base characters. - - Spaces and punctuation are distinguished regardless when the - :class:`~pymongo.collation.Collation` strength is at least - :data:`~pymongo.collation.CollationStrength.QUATERNARY`. - - """ - - -class CollationMaxVariable: - """ - An enum that defines values for `max_variable` on a - :class:`~pymongo.collation.Collation`. - """ - - PUNCT = "punct" - """Both punctuation and spaces are ignored.""" - - SPACE = "space" - """Spaces alone are ignored.""" - - -class CollationCaseFirst: - """ - An enum that defines values for `case_first` on a - :class:`~pymongo.collation.Collation`. - """ - - UPPER = "upper" - """Sort uppercase characters first.""" - - LOWER = "lower" - """Sort lowercase characters first.""" - - OFF = "off" - """Default for locale or collation strength.""" - - -class Collation: - """Collation - - :param locale: (string) The locale of the collation. This should be a string - that identifies an `ICU locale ID` exactly. For example, ``en_US`` is - valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB - documentation for a list of supported locales. - :param caseLevel: (optional) If ``True``, turn on case sensitivity if - `strength` is 1 or 2 (case sensitivity is implied if `strength` is - greater than 2). Defaults to ``False``. 
- :param caseFirst: (optional) Specify that either uppercase or lowercase - characters take precedence. Must be one of the following values: - - * :data:`~CollationCaseFirst.UPPER` - * :data:`~CollationCaseFirst.LOWER` - * :data:`~CollationCaseFirst.OFF` (the default) - - :param strength: Specify the comparison strength. This is also - known as the ICU comparison level. This must be one of the following - values: - - * :data:`~CollationStrength.PRIMARY` - * :data:`~CollationStrength.SECONDARY` - * :data:`~CollationStrength.TERTIARY` (the default) - * :data:`~CollationStrength.QUATERNARY` - * :data:`~CollationStrength.IDENTICAL` - - Each successive level builds upon the previous. For example, a - `strength` of :data:`~CollationStrength.SECONDARY` differentiates - characters based both on the unadorned base character and its accents. - - :param numericOrdering: If ``True``, order numbers numerically - instead of in collation order (defaults to ``False``). - :param alternate: Specify whether spaces and punctuation are - considered base characters. This must be one of the following values: - - * :data:`~CollationAlternate.NON_IGNORABLE` (the default) - * :data:`~CollationAlternate.SHIFTED` - - :param maxVariable: When `alternate` is - :data:`~CollationAlternate.SHIFTED`, this option specifies what - characters may be ignored. This must be one of the following values: - - * :data:`~CollationMaxVariable.PUNCT` (the default) - * :data:`~CollationMaxVariable.SPACE` - - :param normalization: If ``True``, normalizes text into Unicode - NFD. Defaults to ``False``. - :param backwards: If ``True``, accents on characters are - considered from the back of the word to the front, as it is done in some - French dictionary ordering traditions. Defaults to ``False``. - :param kwargs: Keyword arguments supplying any additional options - to be sent with this Collation object. - - .. 
versionadded: 3.4 - - """ - - __slots__ = ("__document",) - - def __init__( - self, - locale: str, - caseLevel: Optional[bool] = None, - caseFirst: Optional[str] = None, - strength: Optional[int] = None, - numericOrdering: Optional[bool] = None, - alternate: Optional[str] = None, - maxVariable: Optional[str] = None, - normalization: Optional[bool] = None, - backwards: Optional[bool] = None, - **kwargs: Any, - ) -> None: - locale = common.validate_string("locale", locale) - self.__document: dict[str, Any] = {"locale": locale} - if caseLevel is not None: - self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) - if caseFirst is not None: - self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) - if strength is not None: - self.__document["strength"] = common.validate_integer("strength", strength) - if numericOrdering is not None: - self.__document["numericOrdering"] = validate_boolean( - "numericOrdering", numericOrdering - ) - if alternate is not None: - self.__document["alternate"] = common.validate_string("alternate", alternate) - if maxVariable is not None: - self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) - if normalization is not None: - self.__document["normalization"] = validate_boolean("normalization", normalization) - if backwards is not None: - self.__document["backwards"] = validate_boolean("backwards", backwards) - self.__document.update(kwargs) - - @property - def document(self) -> dict[str, Any]: - """The document representation of this collation. - - .. note:: - :class:`Collation` is immutable. Mutating the value of - :attr:`document` does not mutate this :class:`Collation`. 
- """ - return self.__document.copy() - - def __repr__(self) -> str: - document = self.document - return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, Collation): - return self.document == other.document - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -def validate_collation_or_none( - value: Optional[Union[Mapping[str, Any], Collation]] -) -> Optional[dict[str, Any]]: - if value is None: - return None - if isinstance(value, Collation): - return value.document - if isinstance(value, dict): - return value - raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 372cafe5b..d85f263ba 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -42,27 +42,32 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from bson.son import SON from bson.timestamp import Timestamp -from pymongo import ASCENDING, _csot -from pymongo.asynchronous import common, helpers, message +from pymongo import ASCENDING, _csot, common, helpers_shared, message from pymongo.asynchronous.aggregation import ( _CollectionAggregationCommand, _CollectionRawAggregationCommand, ) -from pymongo.asynchronous.bulk import _Bulk +from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.change_stream import CollectionChangeStream -from pymongo.asynchronous.collation import validate_collation_or_none from pymongo.asynchronous.command_cursor import ( AsyncCommandCursor, AsyncRawBatchCommandCursor, ) -from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name from pymongo.asynchronous.cursor import ( AsyncCursor, AsyncRawBatchCursor, ) -from pymongo.asynchronous.helpers import _check_write_command_response -from pymongo.asynchronous.message import 
_UNICODE_REPLACE_CODEC_OPTIONS -from pymongo.asynchronous.operations import ( +from pymongo.collation import validate_collation_or_none +from pymongo.common import _ecoc_coll_name, _esc_coll_name +from pymongo.errors import ( + ConfigurationError, + InvalidName, + InvalidOperation, + OperationFailure, +) +from pymongo.helpers_shared import _check_write_command_response +from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS +from pymongo.operations import ( DeleteMany, DeleteOne, IndexModel, @@ -75,15 +80,8 @@ from pymongo.asynchronous.operations import ( _IndexList, _Op, ) -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline -from pymongo.errors import ( - ConfigurationError, - InvalidName, - InvalidOperation, - OperationFailure, -) from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ( BulkWriteResult, DeleteResult, @@ -91,6 +89,7 @@ from pymongo.results import ( InsertOneResult, UpdateResult, ) +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean _IS_SYNC = False @@ -127,11 +126,11 @@ class ReturnDocument: if TYPE_CHECKING: import bson from pymongo.asynchronous.aggregation import _AggregationCommand - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.collation import Collation + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.database import AsyncDatabase - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server + from pymongo.collation import Collation from pymongo.read_concern import ReadConcern @@ -147,7 +146,7 @@ class 
AsyncCollection(common.BaseObject, Generic[_DocumentType]): read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> None: """Get / create an asynchronous Mongo collection. @@ -373,7 +372,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ) def _write_concern_for_cmd( - self, cmd: Mapping[str, Any], session: Optional[ClientSession] + self, cmd: Mapping[str, Any], session: Optional[AsyncClientSession] ) -> WriteConcern: raw_wc = cmd.get("writeConcern") if raw_wc is not None: @@ -413,7 +412,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): batch_size: Optional[int] = None, collation: Optional[_CollationIn] = None, start_at_operation_time: Optional[Timestamp] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, start_after: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -494,7 +493,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param start_after: The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. 
@@ -546,13 +545,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): return change_stream async def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> AsyncContextManager[Connection]: + self, session: Optional[AsyncClientSession], operation: str + ) -> AsyncContextManager[AsyncConnection]: return await self._database.client._conn_for_writes(session, operation) async def _command( self, - conn: Connection, + conn: AsyncConnection, command: MutableMapping[str, Any], read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions] = None, @@ -561,13 +560,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, collation: Optional[_CollationIn] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, retryable_write: bool = False, user_fields: Optional[Any] = None, ) -> Mapping[str, Any]: """Internal command helper. - :param conn` - A Connection instance. + :param conn` - A AsyncConnection instance. :param command` - The command itself, as a :class:`~bson.son.SON` instance. :param read_preference` (optional) - The read preference to use. :param codec_options` (optional) - An instance of @@ -581,7 +580,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param collation` (optional) - An instance of :class:`~pymongo.collation.Collation`. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param retryable_write: True if this command is a retryable write. 
:param user_fields: Response fields that should be decoded @@ -613,7 +612,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): name: str, options: MutableMapping[str, Any], collation: Optional[_CollationIn], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], encrypted_fields: Optional[Mapping[str, Any]] = None, qev2_required: bool = False, ) -> None: @@ -646,7 +645,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _create( self, options: MutableMapping[str, Any], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], ) -> None: collation = validate_collation_or_none(options.pop("collation", None)) encrypted_fields = options.pop("encryptedFields", None) @@ -676,7 +675,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): requests: Sequence[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, let: Optional[Mapping] = None, ) -> BulkWriteResult: @@ -726,7 +725,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): write to opt-out of document level validation. Default is ``False``. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param let: Map of parameter names and values. 
Values must be @@ -755,7 +754,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): """ common.validate_list("requests", requests) - blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) + blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let) for request in requests: try: request._add_to_bulk(blk) @@ -775,7 +774,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): write_concern: WriteConcern, op_id: Optional[int], bypass_doc_val: bool, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], comment: Optional[Any] = None, ) -> Any: """Internal helper for inserting a single document.""" @@ -786,7 +785,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): command["comment"] = comment async def _insert_command( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> None: if bypass_doc_val: command["bypassDocumentValidation"] = True @@ -815,7 +814,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): self, document: Union[_DocumentType, RawBSONDocument], bypass_document_validation: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> InsertOneResult: """Insert a single document. @@ -835,7 +834,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): write to opt-out of document level validation. Default is ``False``. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
@@ -881,7 +880,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): documents: Iterable[Union[_DocumentType, RawBSONDocument]], ordered: bool = True, bypass_document_validation: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> InsertManyResult: """Insert an iterable of documents. @@ -904,7 +903,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): write to opt-out of document level validation. Default is ``False``. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -945,14 +944,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): yield (message._INSERT, document) write_concern = self._write_concern_for(session) - blk = _Bulk(self, ordered, bypass_document_validation, comment=comment) + blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment) blk.ops = list(gen()) await blk.execute(write_concern, session, _Op.INSERT) return InsertManyResult(inserted_ids, write_concern.acknowledged) async def _update( self, - conn: Connection, + conn: AsyncConnection, criteria: Mapping[str, Any], document: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, @@ -964,7 +963,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -996,7 +995,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." 
) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) update_doc["hint"] = hint command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} if let is not None: @@ -1051,14 +1050,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: """Internal update / replace helper.""" async def _update( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Optional[Mapping[str, Any]]: return await self._update( conn, @@ -1094,7 +1093,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): bypass_document_validation: bool = False, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: @@ -1143,7 +1142,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. 
Parameters can then be accessed as variables in an @@ -1197,7 +1196,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: @@ -1252,7 +1251,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1310,7 +1309,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): bypass_document_validation: Optional[bool] = None, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: @@ -1352,7 +1351,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. 
Parameters can then be accessed as variables in an @@ -1404,14 +1403,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def drop( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, ) -> None: """Alias for :meth:`~pymongo.database.AsyncDatabase.drop_collection`. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param encrypted_fields: **(BETA)** Document that describes the encrypted fields for @@ -1447,7 +1446,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _delete( self, - conn: Connection, + conn: AsyncConnection, criteria: Mapping[str, Any], multi: bool, write_concern: Optional[WriteConcern] = None, @@ -1455,7 +1454,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ordered: bool = True, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1477,7 +1476,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." 
) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) delete_doc["hint"] = hint command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]} @@ -1510,14 +1509,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ordered: bool = True, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: """Internal delete helper.""" async def _delete( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Mapping[str, Any]: return await self._delete( conn, @@ -1546,7 +1545,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> DeleteResult: @@ -1570,7 +1569,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. 
Parameters can then be accessed as variables in an @@ -1611,7 +1610,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> DeleteResult: @@ -1635,7 +1634,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1737,7 +1736,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): always be returned. Use a dict to exclude fields from the result (e.g. projection={'_id': False}). :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. 
:param skip: the number of documents to omit (from the start of the result set) when returning the results :param limit: the maximum number of results to @@ -1931,8 +1930,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _count_cmd( self, - session: Optional[ClientSession], - conn: Connection, + session: Optional[AsyncClientSession], + conn: AsyncConnection, read_preference: Optional[_ServerMode], cmd: dict[str, Any], collation: Optional[Collation], @@ -1956,11 +1955,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _aggregate_one_result( self, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], cmd: dict[str, Any], collation: Optional[_CollationIn], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], ) -> Optional[Mapping[str, Any]]: """Internal helper to run an aggregate that returns a single result.""" result = await self._command( @@ -2012,9 +2011,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): kwargs["comment"] = comment async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], ) -> int: cmd: dict[str, Any] = {"count": self._name} @@ -2026,7 +2025,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def count_documents( self, filter: Mapping[str, Any], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> int: @@ -2072,7 +2071,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): to count in the collection. Can be an empty document to count all documents. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: See list of options above. @@ -2095,14 +2094,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}}) cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}} if "hint" in kwargs and not isinstance(kwargs["hint"], str): - kwargs["hint"] = helpers._index_document(kwargs["hint"]) + kwargs["hint"] = helpers_shared._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd.update(kwargs) async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], ) -> int: result = await self._aggregate_one_result( @@ -2117,10 +2116,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _retryable_non_cursor_read( self, func: Callable[ - [Optional[ClientSession], Server, Connection, Optional[_ServerMode]], + [Optional[AsyncClientSession], Server, AsyncConnection, Optional[_ServerMode]], Coroutine[Any, Any, T], ], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, ) -> T: """Non-cursor read helper to handle implicit session creation.""" @@ -2131,7 +2130,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def create_indexes( self, indexes: Sequence[IndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list[str]: @@ -2147,7 +2146,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param indexes: A list of :class:`~pymongo.operations.IndexModel` instances. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: optional arguments to the createIndexes @@ -2177,14 +2176,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): @_csot.apply async def _create_indexes( - self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any + self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any ) -> list[str]: """Internal createIndexes helper. :param indexes: A list of :class:`~pymongo.operations.IndexModel` instances. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param kwargs: optional arguments to the createIndexes command (like maxTimeMS) can be passed as keyword arguments. """ @@ -2223,7 +2222,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def create_index( self, keys: _IndexKeyHint, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> str: @@ -2299,7 +2298,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param keys: a single key or a list of (key, direction) pairs specifying the index to create :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: any additional index creation @@ -2340,7 +2339,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def drop_indexes( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2350,7 +2349,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): Raises OperationFailure on an error. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. 
:param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createIndexes @@ -2375,7 +2374,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def drop_index( self, index_or_name: _IndexKeyHint, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2397,7 +2396,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param index_or_name: index (or name of index) to drop :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createIndexes @@ -2424,13 +2423,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _drop_index( self, index_or_name: _IndexKeyHint, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: name = index_or_name if isinstance(index_or_name, list): - name = helpers._gen_index_name(index_or_name) + name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): raise TypeError("index_or_name must be an instance of str or list") @@ -2451,7 +2450,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def list_indexes( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: """Get a cursor over the index documents for this collection. @@ -2462,7 +2461,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): SON([('v', 2), ('key', SON([('_id', 1)])), ('name', '_id_')]) :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. 
:param comment: A user-provided comment to attach to this command. @@ -2480,7 +2479,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _list_indexes( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: codec_options: CodecOptions = CodecOptions(SON) @@ -2492,9 +2491,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): explicit_session = session is not None async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: cmd = {"listIndexes": self._name, "cursor": {}} @@ -2529,7 +2528,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def index_information( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> MutableMapping[str, Any]: """Get information on this collection's indexes. @@ -2551,7 +2550,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): 'x_1': {'unique': True, 'key': [('x', 1)]}} :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2572,7 +2571,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def list_search_indexes( self, name: Optional[str] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncCommandCursor[Mapping[str, Any]]: @@ -2582,7 +2581,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): for. Only indexes with matching index names will be returned. 
If not given, all search indexes for the current collection will be returned. - :param session: a :class:`~pymongo.client_session.ClientSession`. + :param session: a :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2625,7 +2624,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def create_search_index( self, model: Union[Mapping[str, Any], SearchIndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Any = None, **kwargs: Any, ) -> str: @@ -2636,7 +2635,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): instance or a dictionary with a model "definition" and optional "name". :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createSearchIndexes @@ -2655,14 +2654,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def create_search_indexes( self, models: list[SearchIndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list[str]: """Create multiple search indexes for the current collection. :param models: A list of :class:`~pymongo.operations.SearchIndexModel` instances. - :param session: a :class:`~pymongo.client_session.ClientSession`. + :param session: a :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: optional arguments to the createSearchIndexes @@ -2679,7 +2678,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def _create_search_indexes( self, models: list[SearchIndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list[str]: @@ -2711,7 +2710,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def drop_search_index( self, name: str, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2719,7 +2718,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param name: The name of the search index to be deleted. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the dropSearchIndexes @@ -2746,7 +2745,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): self, name: str, definition: Mapping[str, Any], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2755,7 +2754,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param name: The name of the search index to be updated. :param definition: The new search index definition. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: optional arguments to the updateSearchIndexes @@ -2780,7 +2779,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def options( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> MutableMapping[str, Any]: """Get the options set on this collection. @@ -2791,7 +2790,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): dictionary if the collection has not been created yet. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2830,7 +2829,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): aggregation_command: Type[_AggregationCommand], pipeline: _Pipeline, cursor_class: Type[AsyncCommandCursor], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], explicit_session: bool, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -2859,7 +2858,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def aggregate( self, pipeline: _Pipeline, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -2882,7 +2881,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param pipeline: a list of aggregation pipeline stages :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: A dict of parameter names and values. Values must be constant or closed expressions that do not reference document fields. 
Parameters can then be accessed as variables in an @@ -2954,7 +2953,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def aggregate_raw_batches( self, pipeline: _Pipeline, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncRawBatchCursor[_DocumentType]: @@ -3003,7 +3002,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): async def rename( self, new_name: str, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> MutableMapping[str, Any]: @@ -3017,7 +3016,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param new_name: new name for this collection :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: additional arguments to the rename command @@ -3067,7 +3066,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): self, key: str, filter: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list: @@ -3093,7 +3092,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :param filter: A query document that specifies the documents from which to retrieve the distinct values. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: See list of options above. 
@@ -3118,9 +3117,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): cmd["comment"] = comment async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], ) -> list: return ( @@ -3146,7 +3145,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): return_document: bool = ReturnDocument.BEFORE, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping] = None, **kwargs: Any, ) -> Any: @@ -3163,20 +3162,20 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): cmd["let"] = let cmd.update(kwargs) if projection is not None: - cmd["fields"] = helpers._fields_list_to_dict(projection, "projection") + cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection") if sort is not None: - cmd["sort"] = helpers._index_document(sort) + cmd["sort"] = helpers_shared._index_document(sort) if upsert is not None: validate_boolean("upsert", upsert) cmd["upsert"] = upsert if hint is not None: if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) write_concern = self._write_concern_for_cmd(cmd, session) async def _find_and_modify_helper( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3222,7 +3221,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, sort: Optional[_IndexList] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: 
Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -3268,7 +3267,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -3314,7 +3313,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): upsert: bool = False, return_document: bool = ReturnDocument.BEFORE, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -3366,7 +3365,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -3422,7 +3421,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): return_document: bool = ReturnDocument.BEFORE, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -3513,7 +3512,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ``[('field', ASCENDING)]``). 
This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 000df160d..f554d7f25 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -30,22 +30,22 @@ from typing import ( from bson import CodecOptions, _convert_raw_document_lists_to_streams from pymongo.asynchronous.cursor import _ConnectionManager -from pymongo.asynchronous.message import ( +from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.message import ( _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore, ) -from pymongo.asynchronous.response import PinnedResponse -from pymongo.asynchronous.typings import _Address, _DocumentOut, _DocumentType -from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS -from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.response import PinnedResponse +from pymongo.typings import _Address, _DocumentOut, _DocumentType if TYPE_CHECKING: - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection _IS_SYNC = False @@ -62,7 +62,7 @@ class AsyncCommandCursor(Generic[_DocumentType]): address: Optional[_Address], batch_size: int = 0, max_await_time_ms: Optional[int] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = 
None, explicit_session: bool = False, comment: Any = None, ) -> None: @@ -133,7 +133,7 @@ class AsyncCommandCursor(Generic[_DocumentType]): """ return self._postbatchresumetoken - async def _maybe_pin_connection(self, conn: Connection) -> None: + async def _maybe_pin_connection(self, conn: AsyncConnection) -> None: client = self._collection.database.client if not client._should_pin_cursor(self._session): return @@ -188,8 +188,8 @@ class AsyncCommandCursor(Generic[_DocumentType]): return self._address @property - def session(self) -> Optional[ClientSession]: - """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. + def session(self) -> Optional[AsyncClientSession]: + """The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None. .. versionadded:: 3.6 """ @@ -272,7 +272,7 @@ class AsyncCommandCursor(Generic[_DocumentType]): if isinstance(response, PinnedResponse): if not self._sock_mgr: - self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) + self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] if response.from_command: cursor = response.docs[0]["cursor"] documents = cursor["nextBatch"] @@ -384,7 +384,7 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]): address: Optional[_Address], batch_size: int = 0, max_await_time_ms: Optional[int] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, explicit_session: bool = False, comment: Any = None, ) -> None: diff --git a/pymongo/asynchronous/common.py b/pymongo/asynchronous/common.py deleted file mode 100644 index 7dcfa2938..000000000 --- a/pymongo/asynchronous/common.py +++ /dev/null @@ -1,1062 +0,0 @@ -# Copyright 2011-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. 
You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - - -"""Functions and classes common to multiple pymongo modules.""" -from __future__ import annotations - -import datetime -import warnings -from collections import OrderedDict, abc -from difflib import get_close_matches -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterator, - Mapping, - MutableMapping, - NoReturn, - Optional, - Sequence, - Type, - Union, - overload, -) -from urllib.parse import unquote_plus - -from bson import SON -from bson.binary import UuidRepresentation -from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry -from bson.raw_bson import RawBSONDocument -from pymongo.asynchronous.compression_support import ( - validate_compressors, - validate_zlib_compression_level, -) -from pymongo.asynchronous.monitoring import _validate_event_listeners -from pymongo.asynchronous.read_preferences import _MONGOS_MODES, _ServerMode -from pymongo.driver_info import DriverInfo -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.server_api import ServerApi -from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean - -if TYPE_CHECKING: - from pymongo.asynchronous.client_session import ClientSession - -_IS_SYNC = False - -ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict) - -# Defaults until we connect to a server and get updated limits. -MAX_BSON_SIZE = 16 * (1024**2) -MAX_MESSAGE_SIZE: int = 2 * MAX_BSON_SIZE -MIN_WIRE_VERSION = 0 -MAX_WIRE_VERSION = 0 -MAX_WRITE_BATCH_SIZE = 1000 - -# What this version of PyMongo supports. 
-MIN_SUPPORTED_SERVER_VERSION = "3.6" -MIN_SUPPORTED_WIRE_VERSION = 6 -MAX_SUPPORTED_WIRE_VERSION = 21 - -# Frequency to call hello on servers, in seconds. -HEARTBEAT_FREQUENCY = 10 - -# Frequency to clean up unclosed cursors, in seconds. -# See MongoClient._process_kill_cursors. -KILL_CURSOR_FREQUENCY = 1 - -# Frequency to process events queue, in seconds. -EVENTS_QUEUE_FREQUENCY = 1 - -# How long to wait, in seconds, for a suitable server to be found before -# aborting an operation. For example, if the client attempts an insert -# during a replica set election, SERVER_SELECTION_TIMEOUT governs the -# longest it is willing to wait for a new primary to be found. -SERVER_SELECTION_TIMEOUT = 30 - -# Spec requires at least 500ms between hello calls. -MIN_HEARTBEAT_INTERVAL = 0.5 - -# Spec requires at least 60s between SRV rescans. -MIN_SRV_RESCAN_INTERVAL = 60 - -# Default connectTimeout in seconds. -CONNECT_TIMEOUT = 20.0 - -# Default value for maxPoolSize. -MAX_POOL_SIZE = 100 - -# Default value for minPoolSize. -MIN_POOL_SIZE = 0 - -# The maximum number of concurrent connection creation attempts per pool. -MAX_CONNECTING = 2 - -# Default value for maxIdleTimeMS. -MAX_IDLE_TIME_MS: Optional[int] = None - -# Default value for maxIdleTimeMS in seconds. -MAX_IDLE_TIME_SEC: Optional[int] = None - -# Default value for waitQueueTimeoutMS in seconds. -WAIT_QUEUE_TIMEOUT: Optional[int] = None - -# Default value for localThresholdMS. -LOCAL_THRESHOLD_MS = 15 - -# Default value for retryWrites. -RETRY_WRITES = True - -# Default value for retryReads. -RETRY_READS = True - -# The error code returned when a command doesn't exist. -COMMAND_NOT_FOUND_CODES: Sequence[int] = (59,) - -# Error codes to ignore if GridFS calls createIndex on a secondary -UNAUTHORIZED_CODES: Sequence[int] = (13, 16547, 16548) - -# Maximum number of sessions to send in a single endSessions command. -# From the driver sessions spec. 
-_MAX_END_SESSIONS = 10000 - -# Default value for srvServiceName -SRV_SERVICE_NAME = "mongodb" - -# Default value for serverMonitoringMode -SERVER_MONITORING_MODE = "auto" # poll/stream/auto - - -def partition_node(node: str) -> tuple[str, int]: - """Split a host:port string into (host, int(port)) pair.""" - host = node - port = 27017 - idx = node.rfind(":") - if idx != -1: - host, port = node[:idx], int(node[idx + 1 :]) - if host.startswith("["): - host = host[1:-1] - return host, port - - -def clean_node(node: str) -> tuple[str, int]: - """Split and normalize a node name from a hello response.""" - host, port = partition_node(node) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn: - """Raise ConfigurationError with the given key name.""" - msg = f"Unknown option: {key}." - if suggestions: - msg += f" Did you mean one of ({', '.join(suggestions)}) or maybe a camelCase version of one? Refer to docstring." - raise ConfigurationError(msg) - - -# Mapping of URI uuid representation options to valid subtypes. 
-_UUID_REPRESENTATIONS = { - "unspecified": UuidRepresentation.UNSPECIFIED, - "standard": UuidRepresentation.STANDARD, - "pythonLegacy": UuidRepresentation.PYTHON_LEGACY, - "javaLegacy": UuidRepresentation.JAVA_LEGACY, - "csharpLegacy": UuidRepresentation.CSHARP_LEGACY, -} - - -def validate_boolean_or_string(option: str, value: Any) -> bool: - """Validates that value is True, False, 'true', or 'false'.""" - if isinstance(value, str): - if value not in ("true", "false"): - raise ValueError(f"The value of {option} must be 'true' or 'false'") - return value == "true" - return validate_boolean(option, value) - - -def validate_integer(option: str, value: Any) -> int: - """Validates that 'value' is an integer (or basestring representation).""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - return int(value) - except ValueError: - raise ValueError(f"The value of {option} must be an integer") from None - raise TypeError(f"Wrong type for {option}, value must be an integer") - - -def validate_positive_integer(option: str, value: Any) -> int: - """Validate that 'value' is a positive integer, which does not include 0.""" - val = validate_integer(option, value) - if val <= 0: - raise ValueError(f"The value of {option} must be a positive integer") - return val - - -def validate_non_negative_integer(option: str, value: Any) -> int: - """Validate that 'value' is a positive integer or 0.""" - val = validate_integer(option, value) - if val < 0: - raise ValueError(f"The value of {option} must be a non negative integer") - return val - - -def validate_readable(option: str, value: Any) -> Optional[str]: - """Validates that 'value' is file-like and readable.""" - if value is None: - return value - # First make sure its a string py3.3 open(True, 'r') succeeds - # Used in ssl cert checking due to poor ssl module error reporting - value = validate_string(option, value) - open(value).close() - return value - - -def 
validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]: - """Validate that 'value' is a positive integer or None.""" - if value is None: - return value - return validate_positive_integer(option, value) - - -def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]: - """Validate that 'value' is a positive integer or 0 or None.""" - if value is None: - return value - return validate_non_negative_integer(option, value) - - -def validate_string(option: str, value: Any) -> str: - """Validates that 'value' is an instance of `str`.""" - if isinstance(value, str): - return value - raise TypeError(f"Wrong type for {option}, value must be an instance of str") - - -def validate_string_or_none(option: str, value: Any) -> Optional[str]: - """Validates that 'value' is an instance of `basestring` or `None`.""" - if value is None: - return value - return validate_string(option, value) - - -def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: - """Validates that 'value' is an integer or string.""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - return int(value) - except ValueError: - return value - raise TypeError(f"Wrong type for {option}, value must be an integer or a string") - - -def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: - """Validates that 'value' is an integer or string.""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - val = int(value) - except ValueError: - return value - return validate_non_negative_integer(option, val) - raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string") - - -def validate_positive_float(option: str, value: Any) -> float: - """Validates that 'value' is a float, or can be converted to one, and is - positive. 
- """ - errmsg = f"{option} must be an integer or float" - try: - value = float(value) - except ValueError: - raise ValueError(errmsg) from None - except TypeError: - raise TypeError(errmsg) from None - - # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at - # one billion - this is a reasonable approximation for infinity - if not 0 < value < 1e9: - raise ValueError(f"{option} must be greater than 0 and less than one billion") - return value - - -def validate_positive_float_or_zero(option: str, value: Any) -> float: - """Validates that 'value' is 0 or a positive float, or can be converted to - 0 or a positive float. - """ - if value == 0 or value == "0": - return 0 - return validate_positive_float(option, value) - - -def validate_timeout_or_none(option: str, value: Any) -> Optional[float]: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds. - """ - if value is None: - return value - return validate_positive_float(option, value) / 1000.0 - - -def validate_timeout_or_zero(option: str, value: Any) -> float: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds for the case where None is an error - and 0 is valid. Setting the timeout to nothing in the URI string is a - config error. - """ - if value is None: - raise ConfigurationError(f"{option} cannot be None") - if value == 0 or value == "0": - return 0 - return validate_positive_float(option, value) / 1000.0 - - -def validate_timeout_or_none_or_zero(option: Any, value: Any) -> Optional[float]: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds. value=0 and value="0" are treated the - same as value=None which means unlimited timeout. 
- """ - if value is None or value == 0 or value == "0": - return None - return validate_positive_float(option, value) / 1000.0 - - -def validate_timeoutms(option: Any, value: Any) -> Optional[float]: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds. - """ - if value is None: - return None - return validate_positive_float_or_zero(option, value) / 1000.0 - - -def validate_max_staleness(option: str, value: Any) -> int: - """Validates maxStalenessSeconds according to the Max Staleness Spec.""" - if value == -1 or value == "-1": - # Default: No maximum staleness. - return -1 - return validate_positive_integer(option, value) - - -def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: - """Validate a read preference.""" - if not isinstance(value, _ServerMode): - raise TypeError(f"{value!r} is not a read preference.") - return value - - -def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: - """Validate read preference mode for a MongoClient. - - .. versionchanged:: 3.5 - Returns the original ``value`` instead of the validated read preference - mode. - """ - if value not in _MONGOS_MODES: - raise ValueError(f"{value} is not a valid read preference") - return value - - -def validate_auth_mechanism(option: str, value: Any) -> str: - """Validate the authMechanism URI option.""" - from pymongo.asynchronous.auth import MECHANISMS - - if value not in MECHANISMS: - raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") - return value - - -def validate_uuid_representation(dummy: Any, value: Any) -> int: - """Validate the uuid representation option selected in the URI.""" - try: - return _UUID_REPRESENTATIONS[value] - except KeyError: - raise ValueError( - f"{value} is an invalid UUID representation. 
" - "Must be one of " - f"{tuple(_UUID_REPRESENTATIONS)}" - ) from None - - -def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]: - """Parse readPreferenceTags if passed as a client kwarg.""" - if not isinstance(value, list): - value = [value] - - tag_sets: list = [] - for tag_set in value: - if tag_set == "": - tag_sets.append({}) - continue - try: - tags = {} - for tag in tag_set.split(","): - key, val = tag.split(":") - tags[unquote_plus(key)] = unquote_plus(val) - tag_sets.append(tags) - except Exception: - raise ValueError(f"{tag_set!r} not a valid value for {name}") from None - return tag_sets - - -_MECHANISM_PROPS = frozenset( - [ - "SERVICE_NAME", - "CANONICALIZE_HOST_NAME", - "SERVICE_REALM", - "AWS_SESSION_TOKEN", - "ENVIRONMENT", - "TOKEN_RESOURCE", - ] -) - - -def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Union[bool, str]]: - """Validate authMechanismProperties.""" - props: dict[str, Any] = {} - if not isinstance(value, str): - if not isinstance(value, dict): - raise ValueError("Auth mechanism properties must be given as a string or a dictionary") - for key, value in value.items(): # noqa: B020 - if isinstance(value, str): - props[key] = value - elif isinstance(value, bool): - props[key] = str(value).lower() - elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): - props[key] = value - elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: - from pymongo.asynchronous.auth_oidc import OIDCCallback - - if not isinstance(value, OIDCCallback): - raise ValueError("callback must be an OIDCCallback object") - props[key] = value - else: - raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") - return props - - value = validate_string(option, value) - value = unquote_plus(value) - for opt in value.split(","): - key, _, val = opt.partition(":") - if not val: - raise ValueError("Malformed auth mechanism properties") - if key not in _MECHANISM_PROPS: - # Try not to leak the 
token. - if "AWS_SESSION_TOKEN" in key: - raise ValueError( - "auth mechanism properties must be " - "key:value pairs like AWS_SESSION_TOKEN:" - ) - - raise ValueError( - f"{key} is not a supported auth " - "mechanism property. Must be one of " - f"{tuple(_MECHANISM_PROPS)}." - ) - - if key == "CANONICALIZE_HOST_NAME": - props[key] = validate_boolean_or_string(key, val) - else: - props[key] = val - - return props - - -def validate_document_class( - option: str, value: Any -) -> Union[Type[MutableMapping], Type[RawBSONDocument]]: - """Validate the document_class option.""" - # issubclass can raise TypeError for generic aliases like SON[str, Any]. - # In that case we can use the base class for the comparison. - is_mapping = False - try: - is_mapping = issubclass(value, abc.MutableMapping) - except TypeError: - if hasattr(value, "__origin__"): - is_mapping = issubclass(value.__origin__, abc.MutableMapping) - if not is_mapping and not issubclass(value, RawBSONDocument): - raise TypeError( - f"{option} must be dict, bson.son.SON, " - "bson.raw_bson.RawBSONDocument, or a " - "subclass of collections.MutableMapping" - ) - return value - - -def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: - """Validate the type_registry option.""" - if value is not None and not isinstance(value, TypeRegistry): - raise TypeError(f"{option} must be an instance of {TypeRegistry}") - return value - - -def validate_list(option: str, value: Any) -> list: - """Validates that 'value' is a list.""" - if not isinstance(value, list): - raise TypeError(f"{option} must be a list") - return value - - -def validate_list_or_none(option: Any, value: Any) -> Optional[list]: - """Validates that 'value' is a list or None.""" - if value is None: - return value - return validate_list(option, value) - - -def validate_list_or_mapping(option: Any, value: Any) -> None: - """Validates that 'value' is a list or a document.""" - if not isinstance(value, (abc.Mapping, list)): - raise 
TypeError( - f"{option} must either be a list or an instance of dict, " - "bson.son.SON, or any other type that inherits from " - "collections.Mapping" - ) - - -def validate_is_mapping(option: str, value: Any) -> None: - """Validate the type of method arguments that expect a document.""" - if not isinstance(value, abc.Mapping): - raise TypeError( - f"{option} must be an instance of dict, bson.son.SON, or " - "any other type that inherits from " - "collections.Mapping" - ) - - -def validate_is_document_type(option: str, value: Any) -> None: - """Validate the type of method arguments that expect a MongoDB document.""" - if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): - raise TypeError( - f"{option} must be an instance of dict, bson.son.SON, " - "bson.raw_bson.RawBSONDocument, or " - "a type that inherits from " - "collections.MutableMapping" - ) - - -def validate_appname_or_none(option: str, value: Any) -> Optional[str]: - """Validate the appname option.""" - if value is None: - return value - validate_string(option, value) - # We need length in bytes, so encode utf8 first. 
- if len(value.encode("utf-8")) > 128: - raise ValueError(f"{option} must be <= 128 bytes") - return value - - -def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]: - """Validate the driver keyword arg.""" - if value is None: - return value - if not isinstance(value, DriverInfo): - raise TypeError(f"{option} must be an instance of DriverInfo") - return value - - -def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: - """Validate the server_api keyword arg.""" - if value is None: - return value - if not isinstance(value, ServerApi): - raise TypeError(f"{option} must be an instance of ServerApi") - return value - - -def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: - """Validates that 'value' is a callable.""" - if value is None: - return value - if not callable(value): - raise ValueError(f"{option} must be a callable") - return value - - -def validate_ok_for_replace(replacement: Mapping[str, Any]) -> None: - """Validate a replacement document.""" - validate_is_mapping("replacement", replacement) - # Replacement can be {} - if replacement and not isinstance(replacement, RawBSONDocument): - first = next(iter(replacement)) - if first.startswith("$"): - raise ValueError("replacement can not include $ operators") - - -def validate_ok_for_update(update: Any) -> None: - """Validate an update document.""" - validate_list_or_mapping("update", update) - # Update cannot be {}. 
- if not update: - raise ValueError("update cannot be empty") - - is_document = not isinstance(update, list) - first = next(iter(update)) - if is_document and not first.startswith("$"): - raise ValueError("update only works with $ operators") - - -_UNICODE_DECODE_ERROR_HANDLERS = frozenset(["strict", "replace", "ignore"]) - - -def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str: - """Validate the Unicode decode error handler option of CodecOptions.""" - if value not in _UNICODE_DECODE_ERROR_HANDLERS: - raise ValueError( - f"{value} is an invalid Unicode decode error handler. " - "Must be one of " - f"{tuple(_UNICODE_DECODE_ERROR_HANDLERS)}" - ) - return value - - -def validate_tzinfo(dummy: Any, value: Any) -> Optional[datetime.tzinfo]: - """Validate the tzinfo option""" - if value is not None and not isinstance(value, datetime.tzinfo): - raise TypeError("%s must be an instance of datetime.tzinfo" % value) - return value - - -def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[Any]: - """Validate the driver keyword arg.""" - if value is None: - return value - from pymongo.asynchronous.encryption_options import AutoEncryptionOpts - - if not isinstance(value, AutoEncryptionOpts): - raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") - - return value - - -def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeConversion]: - """Validate a DatetimeConversion string.""" - if value is None: - return DatetimeConversion.DATETIME - - if isinstance(value, str): - if value.isdigit(): - return DatetimeConversion(int(value)) - return DatetimeConversion[value] - elif isinstance(value, int): - return DatetimeConversion(value) - - raise TypeError(f"{option} must be a str or int representing DatetimeConversion") - - -def validate_server_monitoring_mode(option: str, value: str) -> str: - """Validate the serverMonitoringMode option.""" - if value not in {"auto", "stream", "poll"}: - raise 
ValueError( - f'{option}={value!r} is invalid. Must be one of "auto", "stream", or "poll"' - ) - return value - - -# Dictionary where keys are the names of public URI options, and values -# are lists of aliases for that option. -URI_OPTIONS_ALIAS_MAP: dict[str, list[str]] = { - "tls": ["ssl"], -} - -# Dictionary where keys are the names of URI options, and values -# are functions that validate user-input values for that option. If an option -# alias uses a different validator than its public counterpart, it should be -# included here as a key, value pair. -URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { - "appname": validate_appname_or_none, - "authmechanism": validate_auth_mechanism, - "authmechanismproperties": validate_auth_mechanism_properties, - "authsource": validate_string, - "compressors": validate_compressors, - "connecttimeoutms": validate_timeout_or_none_or_zero, - "directconnection": validate_boolean_or_string, - "heartbeatfrequencyms": validate_timeout_or_none, - "journal": validate_boolean_or_string, - "localthresholdms": validate_positive_float_or_zero, - "maxidletimems": validate_timeout_or_none, - "maxconnecting": validate_positive_integer, - "maxpoolsize": validate_non_negative_integer_or_none, - "maxstalenessseconds": validate_max_staleness, - "readconcernlevel": validate_string_or_none, - "readpreference": validate_read_preference_mode, - "readpreferencetags": validate_read_preference_tags, - "replicaset": validate_string_or_none, - "retryreads": validate_boolean_or_string, - "retrywrites": validate_boolean_or_string, - "loadbalanced": validate_boolean_or_string, - "serverselectiontimeoutms": validate_timeout_or_zero, - "sockettimeoutms": validate_timeout_or_none_or_zero, - "tls": validate_boolean_or_string, - "tlsallowinvalidcertificates": validate_boolean_or_string, - "tlsallowinvalidhostnames": validate_boolean_or_string, - "tlscafile": validate_readable, - "tlscertificatekeyfile": validate_readable, - 
"tlscertificatekeyfilepassword": validate_string_or_none, - "tlsdisableocspendpointcheck": validate_boolean_or_string, - "tlsinsecure": validate_boolean_or_string, - "w": validate_non_negative_int_or_basestring, - "wtimeoutms": validate_non_negative_integer, - "zlibcompressionlevel": validate_zlib_compression_level, - "srvservicename": validate_string, - "srvmaxhosts": validate_non_negative_integer, - "timeoutms": validate_timeoutms, - "servermonitoringmode": validate_server_monitoring_mode, -} - -# Dictionary where keys are the names of URI options specific to pymongo, -# and values are functions that validate user-input values for those options. -NONSPEC_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { - "connect": validate_boolean_or_string, - "driver": validate_driver_or_none, - "server_api": validate_server_api_or_none, - "fsync": validate_boolean_or_string, - "minpoolsize": validate_non_negative_integer, - "tlscrlfile": validate_readable, - "tz_aware": validate_boolean_or_string, - "unicode_decode_error_handler": validate_unicode_decode_error_handler, - "uuidrepresentation": validate_uuid_representation, - "waitqueuemultiple": validate_non_negative_integer_or_none, - "waitqueuetimeoutms": validate_timeout_or_none, - "datetime_conversion": validate_datetime_conversion, -} - -# Dictionary where keys are the names of keyword-only options for the -# MongoClient constructor, and values are functions that validate user-input -# values for those options. 
-KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = { - "document_class": validate_document_class, - "type_registry": validate_type_registry, - "read_preference": validate_read_preference, - "event_listeners": _validate_event_listeners, - "tzinfo": validate_tzinfo, - "username": validate_string_or_none, - "password": validate_string_or_none, - "server_selector": validate_is_callable_or_none, - "auto_encryption_opts": validate_auto_encryption_opts_or_none, - "authoidcallowedhosts": validate_list, -} - -# Dictionary where keys are any URI option name, and values are the -# internally-used names of that URI option. Options with only one name -# variant need not be included here. Options whose public and internal -# names are the same need not be included here. -INTERNAL_URI_OPTION_NAME_MAP: dict[str, str] = { - "ssl": "tls", -} - -# Map from deprecated URI option names to a tuple indicating the method of -# their deprecation and any additional information that may be needed to -# construct the warning message. -URI_OPTIONS_DEPRECATION_MAP: dict[str, tuple[str, str]] = { - # format: : (, ), - # Supported values: - # - 'renamed': should be the new option name. Note that case is - # preserved for renamed options as they are part of user warnings. - # - 'removed': may suggest the rationale for deprecating the - # option and/or recommend remedial action. - # For example: - # 'wtimeout': ('renamed', 'wTimeoutMS'), -} - -# Augment the option validator map with pymongo-specific option information. -URI_OPTIONS_VALIDATOR_MAP.update(NONSPEC_OPTIONS_VALIDATOR_MAP) -for optname, aliases in URI_OPTIONS_ALIAS_MAP.items(): - for alias in aliases: - if alias not in URI_OPTIONS_VALIDATOR_MAP: - URI_OPTIONS_VALIDATOR_MAP[alias] = URI_OPTIONS_VALIDATOR_MAP[optname] - -# Map containing all URI option and keyword argument validators. 
-VALIDATORS: dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy() -VALIDATORS.update(KW_VALIDATORS) - -# List of timeout-related options. -TIMEOUT_OPTIONS: list[str] = [ - "connecttimeoutms", - "heartbeatfrequencyms", - "maxidletimems", - "maxstalenessseconds", - "serverselectiontimeoutms", - "sockettimeoutms", - "waitqueuetimeoutms", -] - -_AUTH_OPTIONS = frozenset(["authmechanismproperties"]) - - -def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: - """Validate optional authentication parameters.""" - lower, value = validate(option, value) - if lower not in _AUTH_OPTIONS: - raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}") - return option, value - - -def _get_validator( - key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None -) -> Callable: - normed_key = normed_key or key - try: - return validators[normed_key] - except KeyError: - suggestions = get_close_matches(normed_key, validators, cutoff=0.2) - raise_config_error(key, suggestions) - - -def validate(option: str, value: Any) -> tuple[str, Any]: - """Generic validation function.""" - validator = _get_validator(option, VALIDATORS, normed_key=option.lower()) - value = validator(option, value) - return option, value - - -def get_validated_options( - options: Mapping[str, Any], warn: bool = True -) -> MutableMapping[str, Any]: - """Validate each entry in options and raise a warning if it is not valid. - Returns a copy of options with invalid entries removed. - - :param opts: A dict containing MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. Otherwise, invalid options will - cause errors. 
- """ - validated_options: MutableMapping[str, Any] - if isinstance(options, _CaseInsensitiveDictionary): - validated_options = _CaseInsensitiveDictionary() - - def get_normed_key(x: str) -> str: - return x - - def get_setter_key(x: str) -> str: - return options.cased_key(x) # type: ignore[attr-defined] - - else: - validated_options = {} - - def get_normed_key(x: str) -> str: - return x.lower() - - def get_setter_key(x: str) -> str: - return x - - for opt, value in options.items(): - normed_key = get_normed_key(opt) - try: - validator = _get_validator(opt, URI_OPTIONS_VALIDATOR_MAP, normed_key=normed_key) - validated = validator(opt, value) - except (ValueError, TypeError, ConfigurationError) as exc: - if warn: - warnings.warn(str(exc), stacklevel=2) - else: - raise - else: - validated_options[get_setter_key(normed_key)] = validated - return validated_options - - -def _esc_coll_name(encrypted_fields: Mapping[str, Any], name: str) -> Any: - return encrypted_fields.get("escCollection", f"enxcol_.{name}.esc") - - -def _ecoc_coll_name(encrypted_fields: Mapping[str, Any], name: str) -> Any: - return encrypted_fields.get("ecocCollection", f"enxcol_.{name}.ecoc") - - -# List of write-concern-related options. -WRITE_CONCERN_OPTIONS = frozenset(["w", "wtimeout", "wtimeoutms", "fsync", "j", "journal"]) - - -class BaseObject: - """A base class that provides attributes and methods common - to multiple pymongo classes. - - SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB. - """ - - def __init__( - self, - codec_options: CodecOptions, - read_preference: _ServerMode, - write_concern: WriteConcern, - read_concern: ReadConcern, - ) -> None: - if not isinstance(codec_options, CodecOptions): - raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") - self._codec_options = codec_options - - if not isinstance(read_preference, _ServerMode): - raise TypeError( - f"{read_preference!r} is not valid for read_preference. 
See " - "pymongo.read_preferences for valid " - "options." - ) - self._read_preference = read_preference - - if not isinstance(write_concern, WriteConcern): - raise TypeError( - "write_concern must be an instance of pymongo.write_concern.WriteConcern" - ) - self._write_concern = write_concern - - if not isinstance(read_concern, ReadConcern): - raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern") - self._read_concern = read_concern - - @property - def codec_options(self) -> CodecOptions: - """Read only access to the :class:`~bson.codec_options.CodecOptions` - of this instance. - """ - return self._codec_options - - @property - def write_concern(self) -> WriteConcern: - """Read only access to the :class:`~pymongo.write_concern.WriteConcern` - of this instance. - - .. versionchanged:: 3.0 - The :attr:`write_concern` attribute is now read only. - """ - return self._write_concern - - def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern: - """Read only access to the write concern of this instance or session.""" - # Override this operation's write concern with the transaction's. - if session and session.in_transaction: - return DEFAULT_WRITE_CONCERN - return self.write_concern - - @property - def read_preference(self) -> _ServerMode: - """Read only access to the read preference of this instance. - - .. versionchanged:: 3.0 - The :attr:`read_preference` attribute is now read only. - """ - return self._read_preference - - def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode: - """Read only access to the read preference of this instance or session.""" - # Override this operation's read preference with the transaction's. - if session: - return session._txn_read_preference() or self._read_preference - return self._read_preference - - @property - def read_concern(self) -> ReadConcern: - """Read only access to the :class:`~pymongo.read_concern.ReadConcern` - of this instance. - - .. 
versionadded:: 3.2 - """ - return self._read_concern - - -class _CaseInsensitiveDictionary(MutableMapping[str, Any]): - def __init__(self, *args: Any, **kwargs: Any): - self.__casedkeys: dict[str, Any] = {} - self.__data: dict[str, Any] = {} - self.update(dict(*args, **kwargs)) - - def __contains__(self, key: str) -> bool: # type: ignore[override] - return key.lower() in self.__data - - def __len__(self) -> int: - return len(self.__data) - - def __iter__(self) -> Iterator[str]: - return (key for key in self.__casedkeys) - - def __repr__(self) -> str: - return str({self.__casedkeys[k]: self.__data[k] for k in self}) - - def __setitem__(self, key: str, value: Any) -> None: - lc_key = key.lower() - self.__casedkeys[lc_key] = key - self.__data[lc_key] = value - - def __getitem__(self, key: str) -> Any: - return self.__data[key.lower()] - - def __delitem__(self, key: str) -> None: - lc_key = key.lower() - del self.__casedkeys[lc_key] - del self.__data[lc_key] - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, abc.Mapping): - return NotImplemented - if len(self) != len(other): - return False - for key in other: # noqa: SIM110 - if self[key] != other[key]: - return False - - return True - - def get(self, key: str, default: Optional[Any] = None) -> Any: - return self.__data.get(key.lower(), default) - - def pop(self, key: str, *args: Any, **kwargs: Any) -> Any: - lc_key = key.lower() - self.__casedkeys.pop(lc_key, None) - return self.__data.pop(lc_key, *args, **kwargs) - - def popitem(self) -> tuple[str, Any]: - lc_key, cased_key = self.__casedkeys.popitem() - value = self.__data.pop(lc_key) - return cased_key, value - - def clear(self) -> None: - self.__casedkeys.clear() - self.__data.clear() - - @overload - def setdefault(self, key: str, default: None = None) -> Optional[Any]: - ... - - @overload - def setdefault(self, key: str, default: Any) -> Any: - ... 
- - def setdefault(self, key: str, default: Optional[Any] = None) -> Optional[Any]: - lc_key = key.lower() - if key in self: - return self.__data[lc_key] - else: - self.__casedkeys[lc_key] = key - self.__data[lc_key] = default - return default - - def update(self, other: Mapping[str, Any]) -> None: # type: ignore[override] - if isinstance(other, _CaseInsensitiveDictionary): - for key in other: - self[other.cased_key(key)] = other[key] - else: - for key in other: - self[key] = other[key] - - def cased_key(self, key: str) -> Any: - return self.__casedkeys[key.lower()] diff --git a/pymongo/asynchronous/compression_support.py b/pymongo/asynchronous/compression_support.py deleted file mode 100644 index 8a39bfb46..000000000 --- a/pymongo/asynchronous/compression_support.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2018 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import warnings -from typing import Any, Iterable, Optional, Union - -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.helpers_constants import _SENSITIVE_COMMANDS - -_IS_SYNC = False - - -_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} -_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} -_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) - - -def _have_snappy() -> bool: - try: - import snappy # type:ignore[import] # noqa: F401 - - return True - except ImportError: - return False - - -def _have_zlib() -> bool: - try: - import zlib # noqa: F401 - - return True - except ImportError: - return False - - -def _have_zstd() -> bool: - try: - import zstandard # noqa: F401 - - return True - except ImportError: - return False - - -def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]: - try: - # `value` is string. - compressors = value.split(",") # type: ignore[union-attr] - except AttributeError: - # `value` is an iterable. - compressors = list(value) - - for compressor in compressors[:]: - if compressor not in _SUPPORTED_COMPRESSORS: - compressors.remove(compressor) - warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2) - elif compressor == "snappy" and not _have_snappy(): - compressors.remove(compressor) - warnings.warn( - "Wire protocol compression with snappy is not available. " - "You must install the python-snappy module for snappy support.", - stacklevel=2, - ) - elif compressor == "zlib" and not _have_zlib(): - compressors.remove(compressor) - warnings.warn( - "Wire protocol compression with zlib is not available. " - "The zlib module is not available.", - stacklevel=2, - ) - elif compressor == "zstd" and not _have_zstd(): - compressors.remove(compressor) - warnings.warn( - "Wire protocol compression with zstandard is not available. 
" - "You must install the zstandard module for zstandard support.", - stacklevel=2, - ) - return compressors - - -def validate_zlib_compression_level(option: str, value: Any) -> int: - try: - level = int(value) - except Exception: - raise TypeError(f"{option} must be an integer, not {value!r}.") from None - if level < -1 or level > 9: - raise ValueError("%s must be between -1 and 9, not %d." % (option, level)) - return level - - -class CompressionSettings: - def __init__(self, compressors: list[str], zlib_compression_level: int): - self.compressors = compressors - self.zlib_compression_level = zlib_compression_level - - def get_compression_context( - self, compressors: Optional[list[str]] - ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: - if compressors: - chosen = compressors[0] - if chosen == "snappy": - return SnappyContext() - elif chosen == "zlib": - return ZlibContext(self.zlib_compression_level) - elif chosen == "zstd": - return ZstdContext() - return None - return None - - -class SnappyContext: - compressor_id = 1 - - @staticmethod - def compress(data: bytes) -> bytes: - import snappy - - return snappy.compress(data) - - -class ZlibContext: - compressor_id = 2 - - def __init__(self, level: int): - self.level = level - - def compress(self, data: bytes) -> bytes: - import zlib - - return zlib.compress(data, self.level) - - -class ZstdContext: - compressor_id = 3 - - @staticmethod - def compress(data: bytes) -> bytes: - # ZstdCompressor is not thread safe. - # TODO: Use a pool? - - import zstandard - - return zstandard.ZstdCompressor().compress(data) - - -def decompress(data: bytes, compressor_id: int) -> bytes: - if compressor_id == SnappyContext.compressor_id: - # python-snappy doesn't support the buffer interface. - # https://github.com/andrix/python-snappy/issues/65 - # This only matters when data is a memoryview since - # id(bytes(data)) == id(data) when data is a bytes. 
- import snappy - - return snappy.uncompress(bytes(data)) - elif compressor_id == ZlibContext.compressor_id: - import zlib - - return zlib.decompress(data) - elif compressor_id == ZstdContext.compressor_id: - # ZstdDecompressor is not thread safe. - # TODO: Use a pool? - import zstandard - - return zstandard.ZstdDecompressor().decompress(data) - else: - raise ValueError("Unknown compressorId %d" % (compressor_id,)) diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 5b4771bf8..469d09a98 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -36,14 +36,17 @@ from typing import ( from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON -from pymongo.asynchronous import helpers -from pymongo.asynchronous.collation import validate_collation_or_none -from pymongo.asynchronous.common import ( +from pymongo import helpers_shared +from pymongo.asynchronous.helpers import anext +from pymongo.collation import validate_collation_or_none +from pymongo.common import ( validate_is_document_type, validate_is_mapping, ) -from pymongo.asynchronous.helpers import anext -from pymongo.asynchronous.message import ( +from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.lock import _ALock, _create_lock +from pymongo.message import ( _CursorAddress, _GetMore, _OpMsg, @@ -52,21 +55,18 @@ from pymongo.asynchronous.message import ( _RawBatchGetMore, _RawBatchQuery, ) -from pymongo.asynchronous.response import PinnedResponse -from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType -from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort -from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _ALock, 
_create_lock +from pymongo.response import PinnedResponse +from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import validate_boolean if TYPE_CHECKING: from _typeshed import SupportsItems from bson.codec_options import CodecOptions - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.read_preferences import _ServerMode _IS_SYNC = False @@ -74,8 +74,8 @@ _IS_SYNC = False class _ConnectionManager: """Used with exhaust cursors to ensure the connection is returned.""" - def __init__(self, conn: Connection, more_to_come: bool): - self.conn: Optional[Connection] = conn + def __init__(self, conn: AsyncConnection, more_to_come: bool): + self.conn: Optional[AsyncConnection] = conn self.more_to_come = more_to_come self._alock = _ALock(_create_lock()) @@ -116,7 +116,7 @@ class AsyncCursor(Generic[_DocumentType]): show_record_id: Optional[bool] = None, snapshot: Optional[bool] = None, comment: Optional[Any] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, allow_disk_use: Optional[bool] = None, let: Optional[bool] = None, ) -> None: @@ -134,7 +134,7 @@ class AsyncCursor(Generic[_DocumentType]): self._exhaust = False self._sock_mgr: Any = None self._killed = False - self._session: Optional[ClientSession] + self._session: Optional[AsyncClientSession] if session: self._session = session @@ -179,7 +179,7 @@ class AsyncCursor(Generic[_DocumentType]): allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use) if projection is not None: - projection = helpers._fields_list_to_dict(projection, "projection") + projection = 
helpers_shared._fields_list_to_dict(projection, "projection") if let is not None: validate_is_document_type("let", let) @@ -191,7 +191,7 @@ class AsyncCursor(Generic[_DocumentType]): self._skip = skip self._limit = limit self._batch_size = batch_size - self._ordering = sort and helpers._index_document(sort) or None + self._ordering = sort and helpers_shared._index_document(sort) or None self._max_scan = max_scan self._explain = False self._comment = comment @@ -313,7 +313,7 @@ class AsyncCursor(Generic[_DocumentType]): base.__dict__.update(data) return base - def _clone_base(self, session: Optional[ClientSession]) -> AsyncCursor: + def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor: """Creates an empty Cursor object for information to be copied into.""" return self.__class__(self._collection, session=session) @@ -742,8 +742,8 @@ class AsyncCursor(Generic[_DocumentType]): key, if not given :data:`~pymongo.ASCENDING` is assumed """ self._check_okay_to_chain() - keys = helpers._index_list(key_or_list, direction) - self._ordering = helpers._index_document(keys) + keys = helpers_shared._index_list(key_or_list, direction) + self._ordering = helpers_shared._index_document(keys) return self async def explain(self) -> _DocumentType: @@ -774,7 +774,7 @@ class AsyncCursor(Generic[_DocumentType]): if isinstance(index, str): self._hint = index else: - self._hint = helpers._index_document(index) + self._hint = helpers_shared._index_document(index) def hint(self, index: Optional[_Hint]) -> AsyncCursor[_DocumentType]: """Adds a 'hint', telling Mongo the proper index to use for the query. @@ -927,8 +927,8 @@ class AsyncCursor(Generic[_DocumentType]): return self._address @property - def session(self) -> Optional[ClientSession]: - """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. + def session(self) -> Optional[AsyncClientSession]: + """The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None. .. 
versionadded:: 3.6 """ @@ -1122,7 +1122,7 @@ class AsyncCursor(Generic[_DocumentType]): self._address = response.address if isinstance(response, PinnedResponse): if not self._sock_mgr: - self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) + self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] cmd_name = operation.name docs = response.docs diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 57ad71ece..f7a07027c 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -33,25 +33,24 @@ from typing import ( from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.timestamp import Timestamp -from pymongo import _csot -from pymongo.asynchronous import common +from pymongo import _csot, common from pymongo.asynchronous.aggregation import _DatabaseAggregationCommand from pymongo.asynchronous.change_stream import DatabaseChangeStream from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: import bson import bson.codec_options - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import 
AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern @@ -151,7 +150,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): >>> db1.read_preference Primary() - >>> from pymongo.asynchronous.read_preferences import Secondary + >>> from pymongo.read_preferences import Secondary >>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}])) >>> db1.read_preference Primary() @@ -328,7 +327,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): batch_size: Optional[int] = None, collation: Optional[_CollationIn] = None, start_at_operation_time: Optional[Timestamp] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, start_after: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -402,7 +401,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param start_after: The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. 
@@ -458,7 +457,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, check_exists: Optional[bool] = True, **kwargs: Any, ) -> AsyncCollection[_DocumentType]: @@ -489,7 +488,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): :param collation: An instance of :class:`~pymongo.collation.Collation`. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param `check_exists`: if True (the default), send a listCollections command to check if the collection already exists before creation. :param kwargs: additional keyword arguments will @@ -607,7 +606,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): return coll async def aggregate( - self, pipeline: _Pipeline, session: Optional[ClientSession] = None, **kwargs: Any + self, pipeline: _Pipeline, session: Optional[AsyncClientSession] = None, **kwargs: Any ) -> AsyncCommandCursor[_DocumentType]: """Perform a database-level aggregation. @@ -634,7 +633,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): :param pipeline: a list of aggregation pipeline stages :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param kwargs: extra `aggregate command`_ parameters. 
All optional `aggregate command`_ parameters should be passed as @@ -688,7 +687,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): @overload async def _command( self, - conn: Connection, + conn: AsyncConnection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -697,7 +696,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> dict[str, Any]: ... @@ -705,7 +704,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): @overload async def _command( self, - conn: Connection, + conn: AsyncConnection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -714,14 +713,14 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): codec_options: CodecOptions[_CodecDocumentType] = ..., write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> _CodecDocumentType: ... 
async def _command( self, - conn: Connection, + conn: AsyncConnection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -732,7 +731,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): ] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> Union[dict[str, Any], _CodecDocumentType]: """Internal command helper.""" @@ -763,7 +762,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, codec_options: None = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> dict[str, Any]: @@ -778,7 +777,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, codec_options: CodecOptions[_CodecDocumentType] = ..., - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> _CodecDocumentType: @@ -793,7 +792,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> Union[dict[str, Any], _CodecDocumentType]: @@ -852,7 +851,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): :param codec_options: A :class:`~bson.codec_options.CodecOptions` instance. 
:param session: A - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: additional keyword arguments will @@ -922,7 +921,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): value: Any = 1, read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions[_CodecDocumentType]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, max_await_time_ms: Optional[int] = None, **kwargs: Any, @@ -953,7 +952,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): :param codec_options`: A :class:`~bson.codec_options.CodecOptions` instance. :param session: A - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to future getMores for this command. :param max_await_time_ms: The number of ms to wait for more data on future getMores for this command. 
@@ -1024,15 +1023,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): self, command: Union[str, MutableMapping[str, Any]], operation: str, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> dict[str, Any]: """Same as command but used for retryable read commands.""" read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> dict[str, Any]: return await self._command( @@ -1046,8 +1045,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def _list_collections( self, - conn: Connection, - session: Optional[ClientSession], + conn: AsyncConnection, + session: Optional[AsyncClientSession], read_preference: _ServerMode, **kwargs: Any, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: @@ -1075,7 +1074,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def _list_collections_helper( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1083,7 +1082,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): """Get a cursor over the collections of this database. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param filter: A query document to filter the list of collections returned from the listCollections command. 
:param comment: A user-provided comment to attach to this @@ -1106,9 +1105,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): kwargs["comment"] = comment async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: return await self._list_collections( @@ -1121,7 +1120,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def list_collections( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1129,7 +1128,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): """Get a cursor over the collections of this database. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param filter: A query document to filter the list of collections returned from the listCollections command. :param comment: A user-provided comment to attach to this @@ -1149,7 +1148,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def _list_collection_names( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1174,7 +1173,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def list_collection_names( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1187,7 +1186,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): db.list_collection_names(filter=filter) :param session: a - :class:`~pymongo.client_session.ClientSession`. 
+ :class:`~pymongo.client_session.AsyncClientSession`. :param filter: A query document to filter the list of collections returned from the listCollections command. :param comment: A user-provided comment to attach to this @@ -1207,7 +1206,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): return await self._list_collection_names(session, filter, comment, **kwargs) async def _drop_helper( - self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None + self, name: str, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None ) -> dict[str, Any]: command = {"drop": name} if comment is not None: @@ -1227,7 +1226,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def drop_collection( self, name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, ) -> dict[str, Any]: @@ -1236,7 +1235,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): :param name_or_collection: the name of a collection to drop or the collection object itself :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param encrypted_fields: **(BETA)** Document that describes the encrypted fields for @@ -1306,7 +1305,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], scandata: bool = False, full: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, background: Optional[bool] = None, comment: Optional[Any] = None, ) -> dict[str, Any]: @@ -1326,7 +1325,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): of the structure of the collection and the individual documents. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param background: A boolean flag that determines whether the command runs in the background. Requires MongoDB 4.4+. :param comment: A user-provided comment to attach to this @@ -1386,7 +1385,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): async def dereference( self, dbref: DBRef, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> Optional[_DocumentType]: @@ -1401,7 +1400,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): :param dbref: the reference :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: any additional keyword arguments diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 0043cde75..13f11c326 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -59,16 +59,13 @@ from bson.errors import BSONError from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from pymongo import _csot from pymongo.asynchronous.collection import AsyncCollection -from pymongo.asynchronous.common import CONNECT_TIMEOUT from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.operations import UpdateOne -from pymongo.asynchronous.pool import PoolOptions, _configured_socket, _raise_connection_failure -from pymongo.asynchronous.typings import _DocumentType, _DocumentTypeArg -from pymongo.asynchronous.uri_parser import parse_host +from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure +from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon +from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.errors import ( ConfigurationError, EncryptedCollectionError, @@ -78,9 +75,13 @@ from pymongo.errors import ( ServerSelectionTimeoutError, ) from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall +from pymongo.operations import UpdateOne +from pymongo.pool_options import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context +from pymongo.typings import _DocumentType, _DocumentTypeArg +from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -381,7 +382,10 @@ class _Encrypter: ) io_callbacks = 
_EncryptionIO( # type:ignore[misc] - metadata_client, key_vault_coll, mongocryptd_client, opts + metadata_client, + key_vault_coll, # type:ignore[arg-type] + mongocryptd_client, + opts, ) self._auto_encrypter = AsyncAutoEncrypter( io_callbacks, @@ -459,7 +463,14 @@ class Algorithm(str, enum.Enum): RANGE = "Range" """Range. - .. versionadded:: 4.8 + .. versionadded:: 4.9 + """ + RANGEPREVIEW = "RangePreview" + """**DEPRECATED** - RangePreview. + + .. note:: Support for RangePreview is deprecated. Use :attr:`Algorithm.RANGE` instead. + + .. versionadded:: 4.4 """ @@ -473,7 +484,18 @@ class QueryType(str, enum.Enum): """Used to encrypt a value for an equality query.""" RANGE = "range" - """Used to encrypt a value for a range query.""" + """Used to encrypt a value for a range query. + + .. versionadded:: 4.9 + """ + + RANGEPREVIEW = "RangePreview" + """**DEPRECATED** - Used to encrypt a value for a rangePreview query. + + .. note:: Support for RangePreview is deprecated. Use :attr:`QueryType.RANGE` instead. + + .. versionadded:: 4.4 + """ def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions: @@ -570,7 +592,7 @@ class ClientEncryption(Generic[_DocumentType]): raise ConfigurationError( "client-side field level encryption requires the pymongocrypt " "library: install a compatible version with: " - "python -m pip install 'pymongo[encryption]'" + "python -m pip install --upgrade 'pymongo[encryption]'" ) if not isinstance(codec_options, CodecOptions): @@ -843,7 +865,7 @@ class ClientEncryption(Generic[_DocumentType]): :return: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6. - .. versionchanged:: 4.8 + .. versionchanged:: 4.9 Added the `range_opts` parameter. .. versionchanged:: 4.7 @@ -899,7 +921,7 @@ class ClientEncryption(Generic[_DocumentType]): :return: The encrypted expression, a :class:`~bson.RawBSONDocument`. - .. versionchanged:: 4.8 + .. versionchanged:: 4.9 Added the `range_opts` parameter. .. 
versionchanged:: 4.7 diff --git a/pymongo/asynchronous/encryption_options.py b/pymongo/asynchronous/encryption_options.py deleted file mode 100644 index cdef2d1d1..000000000 --- a/pymongo/asynchronous/encryption_options.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Support for automatic client-side field level encryption.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional - -try: - import pymongocrypt # type:ignore[import] # noqa: F401 - - _HAVE_PYMONGOCRYPT = True -except ImportError: - _HAVE_PYMONGOCRYPT = False -from bson import int64 -from pymongo.asynchronous.common import validate_is_mapping -from pymongo.asynchronous.uri_parser import _parse_kms_tls_options -from pymongo.errors import ConfigurationError - -if TYPE_CHECKING: - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.typings import _DocumentTypeArg - -_IS_SYNC = False - - -class AutoEncryptionOpts: - """Options to configure automatic client-side field level encryption.""" - - def __init__( - self, - kms_providers: Mapping[str, Any], - key_vault_namespace: str, - key_vault_client: Optional[AsyncMongoClient[_DocumentTypeArg]] = None, - schema_map: Optional[Mapping[str, Any]] = None, - bypass_auto_encryption: bool = False, - mongocryptd_uri: str = "mongodb://localhost:27020", - mongocryptd_bypass_spawn: bool = False, - 
mongocryptd_spawn_path: str = "mongocryptd", - mongocryptd_spawn_args: Optional[list[str]] = None, - kms_tls_options: Optional[Mapping[str, Any]] = None, - crypt_shared_lib_path: Optional[str] = None, - crypt_shared_lib_required: bool = False, - bypass_query_analysis: bool = False, - encrypted_fields_map: Optional[Mapping[str, Any]] = None, - ) -> None: - """Options to configure automatic client-side field level encryption. - - Automatic client-side field level encryption requires MongoDB >=4.2 - enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not - supported for operations on a database or view and will result in - error. - - Although automatic encryption requires MongoDB >=4.2 enterprise or a - MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all - users. To configure automatic *decryption* without automatic - *encryption* set ``bypass_auto_encryption=True``. Explicit - encryption and explicit decryption is also supported for all users - with the :class:`~pymongo.encryption.ClientEncryption` class. - - See :ref:`automatic-client-side-encryption` for an example. - - :param kms_providers: Map of KMS provider options. The `kms_providers` - map values differ by provider: - - - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. - These are the AWS access key ID and AWS secret access key used - to generate KMS messages. An optional "sessionToken" may be - included to support temporary AWS credentials. - - `azure`: Map with "tenantId", "clientId", and "clientSecret" as - strings. Additionally, "identityPlatformEndpoint" may also be - specified as a string (defaults to 'login.microsoftonline.com'). - These are the Azure Active Directory credentials used to - generate Azure Key Vault messages. - - `gcp`: Map with "email" as a string and "privateKey" - as `bytes` or a base64 encoded string. - Additionally, "endpoint" may also be specified as a string - (defaults to 'oauth2.googleapis.com'). 
These are the - credentials used to generate Google Cloud KMS messages. - - `kmip`: Map with "endpoint" as a host with required port. - For example: ``{"endpoint": "example.com:443"}``. - - `local`: Map with "key" as `bytes` (96 bytes in length) or - a base64 encoded string which decodes - to 96 bytes. "key" is the master key used to encrypt/decrypt - data keys. This key should be generated and stored as securely - as possible. - - KMS providers may be specified with an optional name suffix - separated by a colon, for example "kmip:name" or "aws:name". - Named KMS providers do not support :ref:`CSFLE on-demand credentials`. - Named KMS providers enables more than one of each KMS provider type to be configured. - For example, to configure multiple local KMS providers:: - - kms_providers = { - "local": {"key": local_kek1}, # Unnamed KMS provider. - "local:myname": {"key": local_kek2}, # Named KMS provider with name "myname". - } - - :param key_vault_namespace: The namespace for the key vault collection. - The key vault collection contains all data keys used for encryption - and decryption. Data keys are stored as documents in this MongoDB - collection. Data keys are protected with encryption by a KMS - provider. - :param key_vault_client: By default, the key vault collection - is assumed to reside in the same MongoDB cluster as the encrypted - AsyncMongoClient. Use this option to route data key queries to a - separate MongoDB cluster. - :param schema_map: Map of collection namespace ("db.coll") to - JSON Schema. By default, a collection's JSONSchema is periodically - polled with the listCollections command. But a JSONSchema may be - specified locally with the schemaMap option. - - **Supplying a `schema_map` provides more security than relying on - JSON Schemas obtained from the server. 
It protects against a - malicious server advertising a false JSON Schema, which could trick - the client into sending unencrypted data that should be - encrypted.** - - Schemas supplied in the schemaMap only apply to configuring - automatic encryption for client side encryption. Other validation - rules in the JSON schema will not be enforced by the driver and - will result in an error. - :param bypass_auto_encryption: If ``True``, automatic - encryption will be disabled but automatic decryption will still be - enabled. Defaults to ``False``. - :param mongocryptd_uri: The MongoDB URI used to connect - to the *local* mongocryptd process. Defaults to - ``'mongodb://localhost:27020'``. - :param mongocryptd_bypass_spawn: If ``True``, the encrypted - AsyncMongoClient will not attempt to spawn the mongocryptd process. - Defaults to ``False``. - :param mongocryptd_spawn_path: Used for spawning the - mongocryptd process. Defaults to ``'mongocryptd'`` and spawns - mongocryptd from the system path. - :param mongocryptd_spawn_args: A list of string arguments to - use when spawning the mongocryptd process. Defaults to - ``['--idleShutdownTimeoutSecs=60']``. If the list does not include - the ``idleShutdownTimeoutSecs`` option then - ``'--idleShutdownTimeoutSecs=60'`` will be added. - :param kms_tls_options: A map of KMS provider names to TLS - options to use when creating secure connections to KMS providers. - Accepts the same TLS options as - :class:`pymongo.mongo_client.AsyncMongoClient`. For example, to - override the system default CA file:: - - kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} - - Or to supply a client certificate:: - - kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} - :param crypt_shared_lib_path: Override the path to load the crypt_shared library. - :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is - unable to load the crypt_shared library. 
- :param bypass_query_analysis: If ``True``, disable automatic analysis - of outgoing commands. Set `bypass_query_analysis` to use explicit - encryption on indexed fields without the MongoDB Enterprise Advanced - licensed crypt_shared library. - :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents - that described the encrypted fields for Queryable Encryption. For example:: - - { - "db.encryptedCollection": { - "escCollection": "enxcol_.encryptedCollection.esc", - "ecocCollection": "enxcol_.encryptedCollection.ecoc", - "fields": [ - { - "path": "firstName", - "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), - "bsonType": "string", - "queries": {"queryType": "equality"} - }, - { - "path": "ssn", - "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), - "bsonType": "string" - } - ] - } - } - - .. versionchanged:: 4.2 - Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, - and `bypass_query_analysis` parameters. - - .. versionchanged:: 4.0 - Added the `kms_tls_options` parameter and the "kmip" KMS provider. - - .. 
versionadded:: 3.9 - """ - if not _HAVE_PYMONGOCRYPT: - raise ConfigurationError( - "client side encryption requires the pymongocrypt library: " - "install a compatible version with: " - "python -m pip install 'pymongo[encryption]'" - ) - if encrypted_fields_map: - validate_is_mapping("encrypted_fields_map", encrypted_fields_map) - self._encrypted_fields_map = encrypted_fields_map - self._bypass_query_analysis = bypass_query_analysis - self._crypt_shared_lib_path = crypt_shared_lib_path - self._crypt_shared_lib_required = crypt_shared_lib_required - self._kms_providers = kms_providers - self._key_vault_namespace = key_vault_namespace - self._key_vault_client = key_vault_client - self._schema_map = schema_map - self._bypass_auto_encryption = bypass_auto_encryption - self._mongocryptd_uri = mongocryptd_uri - self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn - self._mongocryptd_spawn_path = mongocryptd_spawn_path - if mongocryptd_spawn_args is None: - mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] - self._mongocryptd_spawn_args = mongocryptd_spawn_args - if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError("mongocryptd_spawn_args must be a list") - if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): - self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") - # Maps KMS provider name to a SSLContext. - self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) - self._bypass_query_analysis = bypass_query_analysis - - -class RangeOpts: - """Options to configure encrypted queries using the range algorithm.""" - - def __init__( - self, - sparsity: int, - trim_factor: int, - min: Optional[Any] = None, - max: Optional[Any] = None, - precision: Optional[int] = None, - ) -> None: - """Options to configure encrypted queries using the range algorithm. - - :param sparsity: An integer. - :param trim_factor: An integer. - :param min: A BSON scalar value corresponding to the type being queried. 
- :param max: A BSON scalar value corresponding to the type being queried. - :param precision: An integer, may only be set for double or decimal128 types. - - .. versionadded:: 4.4 - """ - self.min = min - self.max = max - self.sparsity = sparsity - self.trim_factor = trim_factor - self.precision = precision - - @property - def document(self) -> dict[str, Any]: - doc = {} - for k, v in [ - ("sparsity", int64.Int64(self.sparsity)), - ("trimFactor", self.trim_factor), - ("precision", self.precision), - ("min", self.min), - ("max", self.max), - ]: - if v is not None: - doc[k] = v - return doc diff --git a/pymongo/asynchronous/event_loggers.py b/pymongo/asynchronous/event_loggers.py deleted file mode 100644 index 9bb8bb36b..000000000 --- a/pymongo/asynchronous/event_loggers.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2020-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Example event logger classes. - -.. versionadded:: 3.11 - -These loggers can be registered using :func:`register` or -:class:`~pymongo.mongo_client.MongoClient`. - -``monitoring.register(CommandLogger())`` - -or - -``MongoClient(event_listeners=[CommandLogger()])`` -""" -from __future__ import annotations - -import logging - -from pymongo.asynchronous import monitoring - -_IS_SYNC = False - - -class CommandLogger(monitoring.CommandListener): - """A simple listener that logs command events. 
- - Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, - :class:`~pymongo.monitoring.CommandSucceededEvent` and - :class:`~pymongo.monitoring.CommandFailedEvent` events and - logs them at the `INFO` severity level using :mod:`logging`. - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.CommandStartedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} started on server " - f"{event.connection_id}" - ) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"succeeded in {event.duration_micros} " - "microseconds" - ) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"failed in {event.duration_micros} " - "microseconds" - ) - - -class ServerLogger(monitoring.ServerListener): - """A simple listener that logs server discovery events. - - Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, - :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.ServerClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. 
versionadded:: 3.11 - """ - - def opened(self, event: monitoring.ServerOpeningEvent) -> None: - logging.info(f"Server {event.server_address} added to topology {event.topology_id}") - - def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - f"Server {event.server_address} changed type from " - f"{event.previous_description.server_type_name} to " - f"{event.new_description.server_type_name}" - ) - - def closed(self, event: monitoring.ServerClosedEvent) -> None: - logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") - - -class HeartbeatLogger(monitoring.ServerHeartbeatListener): - """A simple listener that logs server heartbeat events. - - Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, - :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, - and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: - logging.info(f"Heartbeat sent to server {event.connection_id}") - - def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: - # The reply.document attribute was added in PyMongo 3.4. - logging.info( - f"Heartbeat to server {event.connection_id} " - "succeeded with reply " - f"{event.reply.document}" - ) - - def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: - logging.warning( - f"Heartbeat to server {event.connection_id} failed with error {event.reply}" - ) - - -class TopologyLogger(monitoring.TopologyListener): - """A simple listener that logs server topology events. 
- - Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, - :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.TopologyClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def opened(self, event: monitoring.TopologyOpenedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} opened") - - def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: - logging.info(f"Topology description updated for topology id {event.topology_id}") - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - f"Topology {event.topology_id} changed type from " - f"{event.previous_description.topology_type_name} to " - f"{event.new_description.topology_type_name}" - ) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event: monitoring.TopologyClosedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} closed") - - -class ConnectionPoolLogger(monitoring.ConnectionPoolListener): - """A simple listener that logs server connection pool events. 
- - Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, - :class:`~pymongo.monitoring.PoolClearedEvent`, - :class:`~pymongo.monitoring.PoolClosedEvent`, - :~pymongo.monitoring.class:`ConnectionCreatedEvent`, - :class:`~pymongo.monitoring.ConnectionReadyEvent`, - :class:`~pymongo.monitoring.ConnectionClosedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, - and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: - logging.info(f"[pool {event.address}] pool created") - - def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: - logging.info(f"[pool {event.address}] pool ready") - - def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: - logging.info(f"[pool {event.address}] pool cleared") - - def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: - logging.info(f"[pool {event.address}] pool closed") - - def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: - logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") - - def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" - ) - - def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] " - f'connection closed, reason: "{event.reason}"' - ) - - def connection_check_out_started( - self, event: monitoring.ConnectionCheckOutStartedEvent - ) -> None: - logging.info(f"[pool {event.address}] connection check out started") - - def connection_check_out_failed(self, event: 
monitoring.ConnectionCheckOutFailedEvent) -> None: - logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") - - def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" - ) - - def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" - ) diff --git a/pymongo/asynchronous/hello_compat.py b/pymongo/asynchronous/hello_compat.py deleted file mode 100644 index 9bc8b088c..000000000 --- a/pymongo/asynchronous/hello_compat.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The HelloCompat class, placed here to break circular import issues.""" -from __future__ import annotations - -_IS_SYNC = False - - -class HelloCompat: - CMD = "hello" - LEGACY_CMD = "ismaster" - PRIMARY = "isWritablePrimary" - LEGACY_PRIMARY = "ismaster" - LEGACY_ERROR = "not master" diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 7ec1c18c5..8a85135c1 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2009-present MongoDB, Inc. +# Copyright 2024-present MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,270 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Bits and pieces used by the driver that don't really fit elsewhere.""" +"""Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations import builtins import sys -import traceback -from collections import abc from typing import ( - TYPE_CHECKING, Any, Callable, - Container, - Iterable, - Mapping, - NoReturn, - Optional, - Sequence, TypeVar, - Union, cast, ) -from pymongo import ASCENDING -from pymongo.asynchronous.hello_compat import HelloCompat from pymongo.errors import ( - CursorNotFound, - DuplicateKeyError, - ExecutionTimeout, - NotPrimaryError, OperationFailure, - WriteConcernError, - WriteError, - WTimeoutError, - _wtimeout_error, ) -from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE - -if TYPE_CHECKING: - from pymongo.asynchronous.operations import _IndexList - from pymongo.asynchronous.typings import _DocumentOut - from pymongo.cursor_shared import _Hint +from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE _IS_SYNC = False - -def _gen_index_name(keys: _IndexList) -> str: - """Generate an index name from the set of fields it is over.""" - return "_".join(["{}_{}".format(*item) for item in keys]) - - -def _index_list( - key_or_list: _Hint, direction: Optional[Union[int, str]] = None -) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: - """Helper to generate a list of (key, direction) pairs. - - Takes such a list, or a single key, or a single key and direction. 
- """ - if direction is not None: - if not isinstance(key_or_list, str): - raise TypeError("Expected a string and a direction") - return [(key_or_list, direction)] - else: - if isinstance(key_or_list, str): - return [(key_or_list, ASCENDING)] - elif isinstance(key_or_list, abc.ItemsView): - return list(key_or_list) # type: ignore[arg-type] - elif isinstance(key_or_list, abc.Mapping): - return list(key_or_list.items()) - elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, key_or_list must be an instance of list") - values: list[tuple[str, int]] = [] - for item in key_or_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - values.append(item) - return values - - -def _index_document(index_list: _IndexList) -> dict[str, Any]: - """Helper to generate an index specifying document. - - Takes a list of (key, direction) pairs. - """ - if not isinstance(index_list, (list, tuple, abc.Mapping)): - raise TypeError( - "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) - ) - if not len(index_list): - raise ValueError("key_or_list must not be empty") - - index: dict[str, Any] = {} - - if isinstance(index_list, abc.Mapping): - for key in index_list: - value = index_list[key] - _validate_index_key_pair(key, value) - index[key] = value - else: - for item in index_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - key, value = item - _validate_index_key_pair(key, value) - index[key] = value - return index - - -def _validate_index_key_pair(key: Any, value: Any) -> None: - if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") - if not isinstance(value, (str, int, abc.Mapping)): - raise TypeError( - "second item in each key pair must be 1, -1, " - "'2d', or another valid MongoDB index specifier." 
- ) - - -def _check_command_response( - response: _DocumentOut, - max_wire_version: Optional[int], - allowable_errors: Optional[Container[Union[int, str]]] = None, - parse_write_concern_error: bool = False, -) -> None: - """Check the response to a command for errors.""" - if "ok" not in response: - # Server didn't recognize our message as a command. - raise OperationFailure( - response.get("$err"), # type: ignore[arg-type] - response.get("code"), - response, - max_wire_version, - ) - - if parse_write_concern_error and "writeConcernError" in response: - _error = response["writeConcernError"] - _labels = response.get("errorLabels") - if _labels: - _error.update({"errorLabels": _labels}) - _raise_write_concern_error(_error) - - if response["ok"]: - return - - details = response - # Mongos returns the error details in a 'raw' object - # for some errors. - if "raw" in response: - for shard in response["raw"].values(): - # Grab the first non-empty raw error from a shard. - if shard.get("errmsg") and not shard.get("ok"): - details = shard - break - - errmsg = details["errmsg"] - code = details.get("code") - - # For allowable errors, only check for error messages when the code is not - # included. 
- if allowable_errors: - if code is not None: - if code in allowable_errors: - return - elif errmsg in allowable_errors: - return - - # Server is "not primary" or "recovering" - if code is not None: - if code in _NOT_PRIMARY_CODES: - raise NotPrimaryError(errmsg, response) - elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: - raise NotPrimaryError(errmsg, response) - - # Other errors - # findAndModify with upsert can raise duplicate key error - if code in (11000, 11001, 12582): - raise DuplicateKeyError(errmsg, code, response, max_wire_version) - elif code == 50: - raise ExecutionTimeout(errmsg, code, response, max_wire_version) - elif code == 43: - raise CursorNotFound(errmsg, code, response, max_wire_version) - - raise OperationFailure(errmsg, code, response, max_wire_version) - - -def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: - # If the last batch had multiple errors only report - # the last error to emulate continue_on_error. - error = write_errors[-1] - if error.get("code") == 11000: - raise DuplicateKeyError(error.get("errmsg"), 11000, error) - raise WriteError(error.get("errmsg"), error.get("code"), error) - - -def _raise_write_concern_error(error: Any) -> NoReturn: - if _wtimeout_error(error): - # Make sure we raise WTimeoutError - raise WTimeoutError(error.get("errmsg"), error.get("code"), error) - raise WriteConcernError(error.get("errmsg"), error.get("code"), error) - - -def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: - """Return the writeConcernError or None.""" - wce = result.get("writeConcernError") - if wce: - # The server reports errorLabels at the top level but it's more - # convenient to attach it to the writeConcernError doc itself. - error_labels = result.get("errorLabels") - if error_labels: - # Copy to avoid changing the original document. 
- wce = wce.copy() - wce["errorLabels"] = error_labels - return wce - - -def _check_write_command_response(result: Mapping[str, Any]) -> None: - """Backward compatibility helper for write command error handling.""" - # Prefer write errors over write concern errors - write_errors = result.get("writeErrors") - if write_errors: - _raise_last_write_error(write_errors) - - wce = _get_wce_doc(result) - if wce: - _raise_write_concern_error(wce) - - -def _fields_list_to_dict( - fields: Union[Mapping[str, Any], Iterable[str]], option_name: str -) -> Mapping[str, Any]: - """Takes a sequence of field names and returns a matching dictionary. - - ["a", "b"] becomes {"a": 1, "b": 1} - - and - - ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} - """ - if isinstance(fields, abc.Mapping): - return fields - - if isinstance(fields, (abc.Sequence, abc.Set)): - if not all(isinstance(field, str) for field in fields): - raise TypeError(f"{option_name} must be a list of key names, each an instance of str") - return dict.fromkeys(fields, 1) - - raise TypeError(f"{option_name} must be a mapping or list of key names") - - -def _handle_exception() -> None: - """Print exceptions raised by subscribers to stderr.""" - # Heavily influenced by logging.Handler.handleError. 
- - # See note here: - # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ - if sys.stderr: - einfo = sys.exc_info() - try: - traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) - except OSError: - pass - finally: - del einfo - - # See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories F = TypeVar("F", bound=Callable[..., Any]) @@ -283,8 +38,8 @@ F = TypeVar("F", bound=Callable[..., Any]) def _handle_reauth(func: F) -> F: async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) - from pymongo.asynchronous.message import _BulkWriteContext - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.message import _BulkWriteContext try: return await func(*args, **kwargs) @@ -292,16 +47,16 @@ def _handle_reauth(func: F) -> F: if no_reauth: raise if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - # Look for an argument that either is a Connection + # Look for an argument that either is a AsyncConnection # or has a connection attribute, so we can trigger # a reauth. conn = None for arg in args: - if isinstance(arg, Connection): + if isinstance(arg, AsyncConnection): conn = arg break if isinstance(arg, _BulkWriteContext): - conn = arg.conn + conn = arg.conn # type: ignore[assignment] break if conn: await conn.authenticate(reauthenticate=True) diff --git a/pymongo/asynchronous/logger.py b/pymongo/asynchronous/logger.py deleted file mode 100644 index 4fe820127..000000000 --- a/pymongo/asynchronous/logger.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2023-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import enum -import logging -import os -import warnings -from typing import Any - -from bson import UuidRepresentation, json_util -from bson.json_util import JSONOptions, _truncate_documents -from pymongo.asynchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason - -_IS_SYNC = False - - -class _CommandStatusMessage(str, enum.Enum): - STARTED = "Command started" - SUCCEEDED = "Command succeeded" - FAILED = "Command failed" - - -class _ServerSelectionStatusMessage(str, enum.Enum): - STARTED = "Server selection started" - SUCCEEDED = "Server selection succeeded" - FAILED = "Server selection failed" - WAITING = "Waiting for suitable server to become available" - - -class _ConnectionStatusMessage(str, enum.Enum): - POOL_CREATED = "Connection pool created" - POOL_READY = "Connection pool ready" - POOL_CLOSED = "Connection pool closed" - POOL_CLEARED = "Connection pool cleared" - - CONN_CREATED = "Connection created" - CONN_READY = "Connection ready" - CONN_CLOSED = "Connection closed" - - CHECKOUT_STARTED = "Connection checkout started" - CHECKOUT_SUCCEEDED = "Connection checked out" - CHECKOUT_FAILED = "Connection checkout failed" - CHECKEDIN = "Connection checked in" - - -_DEFAULT_DOCUMENT_LENGTH = 1000 -_SENSITIVE_COMMANDS = [ - "authenticate", - "saslStart", - "saslContinue", - "getnonce", - "createUser", - "updateUser", - "copydbgetnonce", - "copydbsaslstart", - "copydb", -] -_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"] -_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"] 
-_DOCUMENT_NAMES = ["command", "reply", "failure"] -_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD) -_COMMAND_LOGGER = logging.getLogger("pymongo.command") -_CONNECTION_LOGGER = logging.getLogger("pymongo.connection") -_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection") -_CLIENT_LOGGER = logging.getLogger("pymongo.client") -_VERBOSE_CONNECTION_ERROR_REASONS = { - ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed", - ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed", - ConnectionClosedReason.STALE: "Connection pool was stale", - ConnectionClosedReason.ERROR: "An error occurred while using the connection", - ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection", - ConnectionClosedReason.IDLE: "Connection was idle too long", - ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout", -} - - -def _debug_log(logger: logging.Logger, **fields: Any) -> None: - logger.debug(LogMessage(**fields)) - - -def _verbose_connection_error_reason(reason: str) -> str: - return _VERBOSE_CONNECTION_ERROR_REASONS.get(reason, reason) - - -def _info_log(logger: logging.Logger, **fields: Any) -> None: - logger.info(LogMessage(**fields)) - - -def _log_or_warn(logger: logging.Logger, message: str) -> None: - if logger.isEnabledFor(logging.INFO): - logger.info(message) - else: - # stacklevel=4 ensures that the warning is for the user's code. 
- warnings.warn(message, UserWarning, stacklevel=4) - - -class LogMessage: - __slots__ = ("_kwargs", "_redacted") - - def __init__(self, **kwargs: Any): - self._kwargs = kwargs - self._redacted = False - - def __str__(self) -> str: - self._redact() - return "%s" % ( - json_util.dumps( - self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__() - ) - ) - - def _is_sensitive(self, doc_name: str) -> bool: - is_speculative_authenticate = ( - self._kwargs.pop("speculative_authenticate", False) - or "speculativeAuthenticate" in self._kwargs[doc_name] - ) - is_sensitive_command = ( - "commandName" in self._kwargs and self._kwargs["commandName"] in _SENSITIVE_COMMANDS - ) - - is_sensitive_hello = ( - self._kwargs["commandName"] in _HELLO_COMMANDS and is_speculative_authenticate - ) - - return is_sensitive_command or is_sensitive_hello - - def _redact(self) -> None: - if self._redacted: - return - self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None} - if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"): - self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000 - if "serviceId" in self._kwargs: - self._kwargs["serviceId"] = str(self._kwargs["serviceId"]) - document_length = int(os.getenv("MONGOB_LOG_MAX_DOCUMENT_LENGTH", _DEFAULT_DOCUMENT_LENGTH)) - if document_length < 0: - document_length = _DEFAULT_DOCUMENT_LENGTH - is_server_side_error = self._kwargs.pop("isServerSideError", False) - - for doc_name in _DOCUMENT_NAMES: - doc = self._kwargs.get(doc_name) - if doc: - if doc_name == "failure" and is_server_side_error: - doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS} - if doc_name != "failure" and self._is_sensitive(doc_name): - doc = json_util.dumps({}) - else: - truncated_doc = _truncate_documents(doc, document_length)[0] - doc = json_util.dumps( - truncated_doc, - json_options=_JSON_OPTIONS, - default=lambda o: o.__repr__(), - ) - if len(doc) > 
document_length: - doc = ( - doc.encode()[:document_length].decode("unicode-escape", "ignore") - ) + "..." - self._kwargs[doc_name] = doc - self._redacted = True diff --git a/pymongo/asynchronous/max_staleness_selectors.py b/pymongo/asynchronous/max_staleness_selectors.py deleted file mode 100644 index fadd3b429..000000000 --- a/pymongo/asynchronous/max_staleness_selectors.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Criteria to select ServerDescriptions based on maxStalenessSeconds. - -The Max Staleness Spec says: When there is a known primary P, -a secondary S's staleness is estimated with this formula: - - (S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate) - + heartbeatFrequencyMS - -When there is no known primary, a secondary S's staleness is estimated with: - - SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS - -where "SMax" is the secondary with the greatest lastWriteDate. -""" -from __future__ import annotations - -from typing import TYPE_CHECKING - -from pymongo.errors import ConfigurationError -from pymongo.server_type import SERVER_TYPE - -if TYPE_CHECKING: - from pymongo.asynchronous.server_selectors import Selection - -_IS_SYNC = False - -# Constant defined in Max Staleness Spec: An idle primary writes a no-op every -# 10 seconds to refresh secondaries' lastWriteDate values. 
-IDLE_WRITE_PERIOD = 10 -SMALLEST_MAX_STALENESS = 90 - - -def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None: - # We checked for max staleness -1 before this, it must be positive here. - if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD: - raise ConfigurationError( - "maxStalenessSeconds must be at least heartbeatFrequencyMS +" - " %d seconds. maxStalenessSeconds is set to %d," - " heartbeatFrequencyMS is set to %d." - % (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000) - ) - - if max_staleness < SMALLEST_MAX_STALENESS: - raise ConfigurationError( - "maxStalenessSeconds must be at least %d. " - "maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness) - ) - - -def _with_primary(max_staleness: int, selection: Selection) -> Selection: - """Apply max_staleness, in seconds, to a Selection with a known primary.""" - primary = selection.primary - assert primary - sds = [] - - for s in selection.server_descriptions: - if s.server_type == SERVER_TYPE.RSSecondary: - # See max-staleness.rst for explanation of this formula. - assert s.last_write_date and primary.last_write_date # noqa: PT018 - staleness = ( - (s.last_update_time - s.last_write_date) - - (primary.last_update_time - primary.last_write_date) - + selection.heartbeat_frequency - ) - - if staleness <= max_staleness: - sds.append(s) - else: - sds.append(s) - - return selection.with_server_descriptions(sds) - - -def _no_primary(max_staleness: int, selection: Selection) -> Selection: - """Apply max_staleness, in seconds, to a Selection with no known primary.""" - # Secondary that's replicated the most recent writes. - smax = selection.secondary_with_max_last_write_date() - if not smax: - # No secondaries and no primary, short-circuit out of here. 
- return selection.with_server_descriptions([]) - - sds = [] - - for s in selection.server_descriptions: - if s.server_type == SERVER_TYPE.RSSecondary: - # See max-staleness.rst for explanation of this formula. - assert smax.last_write_date and s.last_write_date # noqa: PT018 - staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency - - if staleness <= max_staleness: - sds.append(s) - else: - sds.append(s) - - return selection.with_server_descriptions(sds) - - -def select(max_staleness: int, selection: Selection) -> Selection: - """Apply max_staleness, in seconds, to a Selection.""" - if max_staleness == -1: - return selection - - # Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or - # ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness < - # heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90. - _validate_max_staleness(max_staleness, selection.heartbeat_frequency) - - if selection.primary: - return _with_primary(max_staleness, selection) - else: - return _no_primary(max_staleness, selection) diff --git a/pymongo/asynchronous/message.py b/pymongo/asynchronous/message.py deleted file mode 100644 index d2f048b40..000000000 --- a/pymongo/asynchronous/message.py +++ /dev/null @@ -1,1760 +0,0 @@ -# Copyright 2009-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for creating `messages -`_ to be sent to -MongoDB. - -.. 
note:: This module is for internal use and is generally not needed by - application developers. -""" -from __future__ import annotations - -import datetime -import logging -import random -import struct -from io import BytesIO as _BytesIO -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Mapping, - MutableMapping, - NoReturn, - Optional, - Union, -) - -import bson -from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode -from bson.int64 import Int64 -from bson.raw_bson import ( - _RAW_ARRAY_BSON_OPTIONS, - DEFAULT_RAW_BSON_OPTIONS, - RawBSONDocument, - _inflate_bson, -) - -try: - from pymongo import _cmessage # type: ignore[attr-defined] - - _use_c = True -except ImportError: - _use_c = False -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.logger import ( - _COMMAND_LOGGER, - _CommandStatusMessage, - _debug_log, -) -from pymongo.asynchronous.read_preferences import ReadPreference -from pymongo.errors import ( - ConfigurationError, - CursorNotFound, - DocumentTooLarge, - ExecutionTimeout, - InvalidOperation, - NotPrimaryError, - OperationFailure, - ProtocolError, -) -from pymongo.write_concern import WriteConcern - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.monitoring import _EventListeners - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.typings import _Address, _DocumentOut - from pymongo.read_concern import ReadConcern - - -_IS_SYNC = False - -MAX_INT32 = 2147483647 -MIN_INT32 = -2147483648 - -# Overhead allowed for encoded command documents. 
-_COMMAND_OVERHEAD = 16382 - -_INSERT = 0 -_UPDATE = 1 -_DELETE = 2 - -_EMPTY = b"" -_BSONOBJ = b"\x03" -_ZERO_8 = b"\x00" -_ZERO_16 = b"\x00\x00" -_ZERO_32 = b"\x00\x00\x00\x00" -_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00" -_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff" -_OP_MAP = { - _INSERT: b"\x04documents\x00\x00\x00\x00\x00", - _UPDATE: b"\x04updates\x00\x00\x00\x00\x00", - _DELETE: b"\x04deletes\x00\x00\x00\x00\x00", -} -_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} - -_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions( - unicode_decode_error_handler="replace" -) - - -def _randint() -> int: - """Generate a pseudo random 32 bit integer.""" - return random.randint(MIN_INT32, MAX_INT32) # noqa: S311 - - -def _maybe_add_read_preference( - spec: MutableMapping[str, Any], read_preference: _ServerMode -) -> MutableMapping[str, Any]: - """Add $readPreference to spec when appropriate.""" - mode = read_preference.mode - document = read_preference.document - # Only add $readPreference if it's something other than primary to avoid - # problems with mongos versions that don't support read preferences. Also, - # for maximum backwards compatibility, don't add $readPreference for - # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting - # the secondaryOkay bit has the same effect). 
- if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1): - if "$query" not in spec: - spec = {"$query": spec} - spec["$readPreference"] = document - return spec - - -def _convert_exception(exception: Exception) -> dict[str, Any]: - """Convert an Exception into a failure document for publishing.""" - return {"errmsg": str(exception), "errtype": exception.__class__.__name__} - - -def _convert_write_result( - operation: str, command: Mapping[str, Any], result: Mapping[str, Any] -) -> dict[str, Any]: - """Convert a legacy write result to write command format.""" - # Based on _merge_legacy from bulk.py - affected = result.get("n", 0) - res = {"ok": 1, "n": affected} - errmsg = result.get("errmsg", result.get("err", "")) - if errmsg: - # The write was successful on at least the primary so don't return. - if result.get("wtimeout"): - res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} - else: - # The write failed. - error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} - if "errInfo" in result: - error["errInfo"] = result["errInfo"] - res["writeErrors"] = [error] - return res - if operation == "insert": - # GLE result for insert is always 0 in most MongoDB versions. - res["n"] = len(command["documents"]) - elif operation == "update": - if "upserted" in result: - res["upserted"] = [{"index": 0, "_id": result["upserted"]}] - # Versions of MongoDB before 2.6 don't return the _id for an - # upsert if _id is not an ObjectId. - elif result.get("updatedExisting") is False and affected == 1: - # If _id is in both the update document *and* the query spec - # the update document _id takes precedence. 
- update = command["updates"][0] - _id = update["u"].get("_id", update["q"].get("_id")) - res["upserted"] = [{"index": 0, "_id": _id}] - return res - - -_OPTIONS = { - "tailable": 2, - "oplogReplay": 8, - "noCursorTimeout": 16, - "awaitData": 32, - "allowPartialResults": 128, -} - - -_MODIFIERS = { - "$query": "filter", - "$orderby": "sort", - "$hint": "hint", - "$comment": "comment", - "$maxScan": "maxScan", - "$maxTimeMS": "maxTimeMS", - "$max": "max", - "$min": "min", - "$returnKey": "returnKey", - "$showRecordId": "showRecordId", - "$showDiskLoc": "showRecordId", # <= MongoDb 3.0 - "$snapshot": "snapshot", -} - - -def _gen_find_command( - coll: str, - spec: Mapping[str, Any], - projection: Optional[Union[Mapping[str, Any], Iterable[str]]], - skip: int, - limit: int, - batch_size: Optional[int], - options: Optional[int], - read_concern: ReadConcern, - collation: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, - allow_disk_use: Optional[bool] = None, -) -> dict[str, Any]: - """Generate a find command document.""" - cmd: dict[str, Any] = {"find": coll} - if "$query" in spec: - cmd.update( - [ - (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) - for key, val in spec.items() - ] - ) - if "$explain" in cmd: - cmd.pop("$explain") - if "$readPreference" in cmd: - cmd.pop("$readPreference") - else: - cmd["filter"] = spec - - if projection: - cmd["projection"] = projection - if skip: - cmd["skip"] = skip - if limit: - cmd["limit"] = abs(limit) - if limit < 0: - cmd["singleBatch"] = True - if batch_size: - cmd["batchSize"] = batch_size - if read_concern.level and not (session and session.in_transaction): - cmd["readConcern"] = read_concern.document - if collation: - cmd["collation"] = collation - if allow_disk_use is not None: - cmd["allowDiskUse"] = allow_disk_use - if options: - cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) - - return cmd - - -def _gen_get_more_command( - cursor_id: Optional[int], - 
coll: str, - batch_size: Optional[int], - max_await_time_ms: Optional[int], - comment: Optional[Any], - conn: Connection, -) -> dict[str, Any]: - """Generate a getMore command document.""" - cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll} - if batch_size: - cmd["batchSize"] = batch_size - if max_await_time_ms is not None: - cmd["maxTimeMS"] = max_await_time_ms - if comment is not None and conn.max_wire_version >= 9: - cmd["comment"] = comment - return cmd - - -class _Query: - """A query operation.""" - - __slots__ = ( - "flags", - "db", - "coll", - "ntoskip", - "spec", - "fields", - "codec_options", - "read_preference", - "limit", - "batch_size", - "name", - "read_concern", - "collation", - "session", - "client", - "allow_disk_use", - "_as_command", - "exhaust", - ) - - # For compatibility with the _GetMore class. - conn_mgr = None - cursor_id = None - - def __init__( - self, - flags: int, - db: str, - coll: str, - ntoskip: int, - spec: Mapping[str, Any], - fields: Optional[Mapping[str, Any]], - codec_options: CodecOptions, - read_preference: _ServerMode, - limit: int, - batch_size: int, - read_concern: ReadConcern, - collation: Optional[Mapping[str, Any]], - session: Optional[ClientSession], - client: AsyncMongoClient, - allow_disk_use: Optional[bool], - exhaust: bool, - ): - self.flags = flags - self.db = db - self.coll = coll - self.ntoskip = ntoskip - self.spec = spec - self.fields = fields - self.codec_options = codec_options - self.read_preference = read_preference - self.read_concern = read_concern - self.limit = limit - self.batch_size = batch_size - self.collation = collation - self.session = session - self.client = client - self.allow_disk_use = allow_disk_use - self.name = "find" - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: Connection) -> bool: 
- use_find_cmd = False - if not self.exhaust: - use_find_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_find_cmd = True - elif not self.read_concern.ok_for_legacy: - raise ConfigurationError( - "read concern level of %s is not valid " - "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) - ) - - conn.validate_session(self.client, self.session) - return use_find_cmd - - async def as_command( - self, conn: Connection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a find command document for this query.""" - # We use the command twice: on the wire and for command monitoring. - # Generate it once, for speed and to avoid repeating side-effects. - if self._as_command is not None: - return self._as_command - - explain = "$explain" in self.spec - cmd: dict[str, Any] = _gen_find_command( - self.coll, - self.spec, - self.fields, - self.ntoskip, - self.limit, - self.batch_size, - self.flags, - self.read_concern, - self.collation, - self.session, - self.allow_disk_use, - ) - if explain: - self.name = "explain" - cmd = {"explain": cmd} - session = self.session - conn.add_server_api(cmd) - if session: - session._apply_to(cmd, False, self.read_preference, conn) - # Explain does not support readConcern. 
- if not explain and not session.in_transaction: - session._update_read_concern(cmd, conn) - conn.send_cluster_time(cmd, session, self.client) - # Support auto encryption - client = self.client - if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = await client._encrypter.encrypt(self.db, cmd, self.codec_options) - # Support CSOT - if apply_timeout: - conn.apply_timeout(client, cmd) - self._as_command = cmd, self.db - return self._as_command - - async def get_message( - self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False - ) -> tuple[int, bytes, int]: - """Get a query message, possibly setting the secondaryOk bit.""" - # Use the read_preference decided by _socket_from_server. - self.read_preference = read_preference - if read_preference.mode: - # Set the secondaryOk bit. - flags = self.flags | 4 - else: - flags = self.flags - - ns = self.namespace() - spec = self.spec - - if use_cmd: - spec = (await self.as_command(conn, apply_timeout=True))[0] - request_id, msg, size, _ = _op_msg( - 0, - spec, - self.db, - read_preference, - self.codec_options, - ctx=conn.compression_context, - ) - return request_id, msg, size - - # OP_QUERY treats ntoreturn of -1 and 1 the same, return - # one document and close the cursor. We have to use 2 for - # batch size if 1 is specified. 
- ntoreturn = self.batch_size == 1 and 2 or self.batch_size - if self.limit: - if ntoreturn: - ntoreturn = min(self.limit, ntoreturn) - else: - ntoreturn = self.limit - - if conn.is_mongos: - assert isinstance(spec, MutableMapping) - spec = _maybe_add_read_preference(spec, read_preference) - - return _query( - flags, - ns, - self.ntoskip, - ntoreturn, - spec, - None if use_cmd else self.fields, - self.codec_options, - ctx=conn.compression_context, - ) - - -class _GetMore: - """A getmore operation.""" - - __slots__ = ( - "db", - "coll", - "ntoreturn", - "cursor_id", - "max_await_time_ms", - "codec_options", - "read_preference", - "session", - "client", - "conn_mgr", - "_as_command", - "exhaust", - "comment", - ) - - name = "getMore" - - def __init__( - self, - db: str, - coll: str, - ntoreturn: int, - cursor_id: int, - codec_options: CodecOptions, - read_preference: _ServerMode, - session: Optional[ClientSession], - client: AsyncMongoClient, - max_await_time_ms: Optional[int], - conn_mgr: Any, - exhaust: bool, - comment: Any, - ): - self.db = db - self.coll = coll - self.ntoreturn = ntoreturn - self.cursor_id = cursor_id - self.codec_options = codec_options - self.read_preference = read_preference - self.session = session - self.client = client - self.max_await_time_ms = max_await_time_ms - self.conn_mgr = conn_mgr - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - self.comment = comment - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: Connection) -> bool: - use_cmd = False - if not self.exhaust: - use_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_cmd = True - - conn.validate_session(self.client, self.session) - return use_cmd - - async def as_command( - self, conn: Connection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a getMore command 
document for this query.""" - # See _Query.as_command for an explanation of this caching. - if self._as_command is not None: - return self._as_command - - cmd: dict[str, Any] = _gen_get_more_command( - self.cursor_id, - self.coll, - self.ntoreturn, - self.max_await_time_ms, - self.comment, - conn, - ) - if self.session: - self.session._apply_to(cmd, False, self.read_preference, conn) - conn.add_server_api(cmd) - conn.send_cluster_time(cmd, self.session, self.client) - # Support auto encryption - client = self.client - if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = await client._encrypter.encrypt(self.db, cmd, self.codec_options) - # Support CSOT - if apply_timeout: - conn.apply_timeout(client, cmd=None) - self._as_command = cmd, self.db - return self._as_command - - async def get_message( - self, dummy0: Any, conn: Connection, use_cmd: bool = False - ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: - """Get a getmore message.""" - ns = self.namespace() - ctx = conn.compression_context - - if use_cmd: - spec = (await self.as_command(conn, apply_timeout=True))[0] - if self.conn_mgr and self.exhaust: - flags = _OpMsg.EXHAUST_ALLOWED - else: - flags = 0 - request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context - ) - return request_id, msg, size - - return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) - - -class _RawBatchQuery(_Query): - def use_command(self, conn: Connection) -> bool: - # Compatibility checks. - super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False - - -class _RawBatchGetMore(_GetMore): - def use_command(self, conn: Connection) -> bool: - # Compatibility checks. 
- super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False - - -class _CursorAddress(tuple): - """The server address (host, port) of a cursor, with namespace property.""" - - __namespace: Any - - def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: - self = tuple.__new__(cls, address) - self.__namespace = namespace - return self - - @property - def namespace(self) -> str: - """The namespace this cursor.""" - return self.__namespace - - def __hash__(self) -> int: - # Two _CursorAddress instances with different namespaces - # must not hash the same. - return ((*self, self.__namespace)).__hash__() - - def __eq__(self, other: object) -> bool: - if isinstance(other, _CursorAddress): - return tuple(self) == tuple(other) and self.namespace == other.namespace - return NotImplemented - - def __ne__(self, other: object) -> bool: - return not self == other - - -_pack_compression_header = struct.Struct(" tuple[int, bytes]: - """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" - compressed = ctx.compress(data) - request_id = _randint() - - header = _pack_compression_header( - _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length - request_id, # Request id - 0, # responseTo - 2012, # operation id - operation, # original operation id - len(data), # uncompressed message length - ctx.compressor_id, - ) # compressor id - return request_id, header + compressed - - -_pack_header = struct.Struct(" tuple[int, bytes]: - """Takes message data and adds a message header based on the operation. - - Returns the resultant message string. - """ - rid = _randint() - message = _pack_header(16 + len(data), rid, 0, operation) - return rid, message + data - - -_pack_int = struct.Struct(" tuple[bytes, int, int]: - """Get a OP_MSG message. 
- - Note: this method handles multiple documents in a type one payload but - it does not perform batch splitting and the total message size is - only checked *after* generating the entire message. - """ - # Encode the command document in payload 0 without checking keys. - encoded = _dict_to_bson(command, False, opts) - flags_type = _pack_op_msg_flags_type(flags, 0) - total_size = len(encoded) - max_doc_size = 0 - if identifier and docs is not None: - type_one = _pack_byte(1) - cstring = _make_c_string(identifier) - encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] - size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 - encoded_size = _pack_int(size) - total_size += size - max_doc_size = max(len(doc) for doc in encoded_docs) - data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] - else: - data = [flags_type, encoded] - return b"".join(data), total_size, max_doc_size - - -def _op_msg_compressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int, int]: - """Internal OP_MSG message helper.""" - msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - rid, msg = _compress(2013, msg, ctx) - return rid, msg, total_size, max_bson_size - - -def _op_msg_uncompressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, -) -> tuple[int, bytes, int, int]: - """Internal compressed OP_MSG message helper.""" - data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - request_id, op_message = __pack_message(2013, data) - return request_id, op_message, total_size, max_bson_size - - -if _use_c: - _op_msg_uncompressed = _cmessage._op_msg - - -def _op_msg( - flags: int, - command: MutableMapping[str, Any], - dbname: str, - 
read_preference: Optional[_ServerMode], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes, int, int]: - """Get a OP_MSG message.""" - command["$db"] = dbname - # getMore commands do not send $readPreference. - if read_preference is not None and "$readPreference" not in command: - # Only send $readPreference if it's not primary (the default). - if read_preference.mode: - command["$readPreference"] = read_preference.document - name = next(iter(command)) - try: - identifier = _FIELD_MAP[name] - docs = command.pop(identifier) - except KeyError: - identifier = "" - docs = None - try: - if ctx: - return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) - return _op_msg_uncompressed(flags, command, identifier, docs, opts) - finally: - # Add the field back to the command. - if identifier: - command[identifier] = docs - - -def _query_impl( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, -) -> tuple[bytes, int]: - """Get an OP_QUERY message.""" - encoded = _dict_to_bson(query, False, opts) - if field_selector: - efs = _dict_to_bson(field_selector, False, opts) - else: - efs = b"" - max_bson_size = max(len(encoded), len(efs)) - return ( - b"".join( - [ - _pack_int(options), - _make_c_string(collection_name), - _pack_int(num_to_skip), - _pack_int(num_to_return), - encoded, - efs, - ] - ), - max_bson_size, - ) - - -def _query_compressed( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int]: - """Internal compressed query message helper.""" - op_query, max_bson_size = _query_impl( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - 
rid, msg = _compress(2004, op_query, ctx) - return rid, msg, max_bson_size - - -def _query_uncompressed( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, -) -> tuple[int, bytes, int]: - """Internal query message helper.""" - op_query, max_bson_size = _query_impl( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - rid, msg = __pack_message(2004, op_query) - return rid, msg, max_bson_size - - -if _use_c: - _query_uncompressed = _cmessage._query_message - - -def _query( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes, int]: - """Get a **query** message.""" - if ctx: - return _query_compressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx - ) - return _query_uncompressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - - -_pack_long_long = struct.Struct(" bytes: - """Get an OP_GET_MORE message.""" - return b"".join( - [ - _ZERO_32, - _make_c_string(collection_name), - _pack_int(num_to_return), - _pack_long_long(cursor_id), - ] - ) - - -def _get_more_compressed( - collection_name: str, - num_to_return: int, - cursor_id: int, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes]: - """Internal compressed getMore message helper.""" - return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) - - -def _get_more_uncompressed( - collection_name: str, num_to_return: int, cursor_id: int -) -> tuple[int, bytes]: - """Internal getMore message helper.""" - return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id)) - - -if _use_c: - 
_get_more_uncompressed = _cmessage._get_more_message - - -def _get_more( - collection_name: str, - num_to_return: int, - cursor_id: int, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes]: - """Get a **getMore** message.""" - if ctx: - return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) - return _get_more_uncompressed(collection_name, num_to_return, cursor_id) - - -class _BulkWriteContext: - """A wrapper around Connection for use with write splitting functions.""" - - __slots__ = ( - "db_name", - "conn", - "op_id", - "name", - "field", - "publish", - "start_time", - "listeners", - "session", - "compress", - "op_type", - "codec", - ) - - def __init__( - self, - database_name: str, - cmd_name: str, - conn: Connection, - operation_id: int, - listeners: _EventListeners, - session: ClientSession, - op_type: int, - codec: CodecOptions, - ): - self.db_name = database_name - self.conn = conn - self.op_id = operation_id - self.listeners = listeners - self.publish = listeners.enabled_for_commands - self.name = cmd_name - self.field = _FIELD_MAP[self.name] - self.start_time = datetime.datetime.now() - self.session = session - self.compress = bool(conn.compression_context) - self.op_type = op_type - self.codec = codec - - def __batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[int, bytes, list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - request_id, msg, to_send = _do_batched_op_msg( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - return request_id, msg, to_send - - async def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - result = await self.write_command(cmd, request_id, msg, to_send, 
client) - await client._process_response(result, self.session) - return result, to_send - - async def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> list[Mapping[str, Any]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - # Though this isn't strictly a "legacy" write, the helper - # handles publishing commands and sending our message - # without receiving a result. Send 0 for max_doc_size - # to disable size checking. Size checking is handled while - # the documents are encoded to BSON. - await self.unack_write(cmd, request_id, msg, 0, to_send, client) - return to_send - - @property - def max_bson_size(self) -> int: - """A proxy for SockInfo.max_bson_size.""" - return self.conn.max_bson_size - - @property - def max_message_size(self) -> int: - """A proxy for SockInfo.max_message_size.""" - if self.compress: - # Subtract 16 bytes for the message header. - return self.conn.max_message_size - 16 - return self.conn.max_message_size - - @property - def max_write_batch_size(self) -> int: - """A proxy for SockInfo.max_write_batch_size.""" - return self.conn.max_write_batch_size - - @property - def max_split_size(self) -> int: - """The maximum size of a BSON command before batch splitting.""" - return self.max_bson_size - - async def unack_write( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - max_doc_size: int, - docs: list[Mapping[str, Any]], - client: AsyncMongoClient, - ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - 
serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - cmd = self._start(cmd, request_id, docs) - try: - result = await self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value] - duration = datetime.datetime.now() - self.start_time - if result is not None: - reply = _convert_write_result(self.name, cmd, result) - else: - # Comply with APM spec. - reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if self.publish: 
- assert self.start_time is not None - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return result - - @_handle_reauth - async def write_command( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - docs: list[Mapping[str, Any]], - client: AsyncMongoClient, - ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" - cmd[self.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._start(cmd, request_id, docs) - try: - reply = await self.conn.write_command(request_id, msg, self.codec) - duration = datetime.datetime.now() - self.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if 
_COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if self.publish: - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return reply - - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - - def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - -# From the Client Side Encryption spec: -# Because automatic encryption increases the size of commands, the driver -# MUST split bulk writes at a reduced size limit before undergoing automatic -# 
encryption. The write payload MUST be split at 2MiB (2097152). -_MAX_SPLIT_SIZE_ENC = 2097152 - - -class _EncryptedBulkWriteContext(_BulkWriteContext): - __slots__ = () - - def __batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - msg, to_send = _encode_batched_write_command( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - - # Chop off the OP_QUERY header to get a properly batched write command. - cmd_start = msg.index(b"\x00", 4) + 9 - outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) - return outgoing, to_send - - async def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - result: Mapping[str, Any] = await self.conn.command( - self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client - ) - return result, to_send - - async def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> list[Mapping[str, Any]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - await self.conn.command( - self.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=self.session, - client=client, - ) - return to_send - - @property - def max_split_size(self) -> int: - """Reduce the batch splitting size.""" - return _MAX_SPLIT_SIZE_ENC - - -def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn: - """Internal helper for raising DocumentTooLarge.""" - if operation == "insert": - raise DocumentTooLarge( - "BSON document too large (%d bytes)" - " - the connected server supports" - " BSON document sizes up to %d" - " bytes." 
% (doc_size, max_size) - ) - else: - # There's nothing intelligent we can say - # about size for update and delete - raise DocumentTooLarge(f"{operation!r} command document too large") - - -# OP_MSG ------------------------------------------------------------- - - -_OP_MSG_MAP = { - _INSERT: b"documents\x00", - _UPDATE: b"updates\x00", - _DELETE: b"deletes\x00", -} - - -def _batched_op_msg_impl( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_MSG write.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - max_message_size = ctx.max_message_size - - flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" - # Flags - buf.write(flags) - - # Type 0 Section - buf.write(b"\x00") - buf.write(_dict_to_bson(command, False, opts)) - - # Type 1 Section - buf.write(b"\x01") - size_location = buf.tell() - # Save space for size - buf.write(b"\x00\x00\x00\x00") - try: - buf.write(_OP_MSG_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None - - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - value = _dict_to_bson(doc, False, opts) - doc_length = len(value) - new_message_size = buf.tell() + doc_length - # Does first document exceed max_message_size? - doc_too_large = idx == 0 and (new_message_size > max_message_size) - # When OP_MSG is used unacknowledged we have to check - # document size client side or applications won't be notified. - # Otherwise we let the server deal with documents that are too large - # since ordered=False causes those documents to be skipped instead of - # halting the bulk write operation. 
- unacked_doc_too_large = not ack and (doc_length > max_bson_size) - if doc_too_large or unacked_doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - # We have enough data, return this batch. - if new_message_size > max_message_size: - break - buf.write(value) - to_send.append(doc) - idx += 1 - # We have enough documents, return this batch. - if idx == max_write_batch_size: - break - - # Write type 1 section size - length = buf.tell() - buf.seek(size_location) - buf.write(_pack_int(length - size_location)) - - return to_send, length - - -def _encode_batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete operation - as OP_MSG. - """ - buf = _BytesIO() - - to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - return buf.getvalue(), to_send - - -if _use_c: - _encode_batched_op_msg = _cmessage._encode_batched_op_msg - - -def _batched_op_msg_compressed( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - with OP_MSG, compressed. 
- """ - data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) - - assert ctx.conn.compression_context is not None - request_id, msg = _compress(2013, data, ctx.conn.compression_context) - return request_id, msg, to_send - - -def _batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """OP_MSG implementation entry point.""" - buf = _BytesIO() - - # Save space for message length and request id - buf.write(_ZERO_64) - # responseTo, opCode - buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") - - to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - - # Header - request id and message length - buf.seek(4) - request_id = _randint() - buf.write(_pack_int(request_id)) - buf.seek(0) - buf.write(_pack_int(length)) - - return request_id, buf.getvalue(), to_send - - -if _use_c: - _batched_op_msg = _cmessage._batched_op_msg - - -def _do_batched_op_msg( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - using OP_MSG. 
- """ - command["$db"] = namespace.split(".", 1)[0] - if "writeConcern" in command: - ack = bool(command["writeConcern"].get("w", 1)) - else: - ack = True - if ctx.conn.compression_context: - return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) - return _batched_op_msg(operation, command, docs, ack, opts, ctx) - - -# End OP_MSG ----------------------------------------------------- - - -def _encode_batched_write_command( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete command.""" - buf = _BytesIO() - - to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) - return buf.getvalue(), to_send - - -if _use_c: - _encode_batched_write_command = _cmessage._encode_batched_write_command - - -def _batched_write_command_impl( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_QUERY write command.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - # Max BSON object size + 16k - 2 bytes for ending NUL bytes. - # Server guarantees there is enough room: SERVER-10643. - max_cmd_size = max_bson_size + _COMMAND_OVERHEAD - max_split_size = ctx.max_split_size - - # No options - buf.write(_ZERO_32) - # Namespace as C string - buf.write(namespace.encode("utf8")) - buf.write(_ZERO_8) - # Skip: 0, Limit: -1 - buf.write(_SKIPLIM) - - # Where to write command document length - command_start = buf.tell() - buf.write(encode(command)) - - # Start of payload - buf.seek(-1, 2) - # Work around some Jython weirdness. 
- buf.truncate() - try: - buf.write(_OP_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None - - # Where to write list document length - list_start = buf.tell() - 4 - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - key = str(idx).encode("utf8") - value = _dict_to_bson(doc, False, opts) - # Is there enough room to add this document? max_cmd_size accounts for - # the two trailing null bytes. - doc_too_large = len(value) > max_cmd_size - if doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size - enough_documents = idx >= max_write_batch_size - if enough_data or enough_documents: - break - buf.write(_BSONOBJ) - buf.write(key) - buf.write(_ZERO_8) - buf.write(value) - to_send.append(doc) - idx += 1 - - # Finalize the current OP_QUERY message. - # Close list and command documents - buf.write(_ZERO_16) - - # Write document lengths and request id - length = buf.tell() - buf.seek(list_start) - buf.write(_pack_int(length - list_start - 1)) - buf.seek(command_start) - buf.write(_pack_int(length - command_start)) - - return to_send, length - - -class _OpReply: - """A MongoDB OP_REPLY response message.""" - - __slots__ = ("flags", "cursor_id", "number_returned", "documents") - - UNPACK_FROM = struct.Struct(" list[bytes]: - """Check the response header from the database, without decoding BSON. - - Check the response for errors and unpack. - - Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or - OperationFailure. - - :param cursor_id: cursor_id we sent to get this response - - used for raising an informative exception when we get cursor id not - valid at server response. 
- """ - if self.flags & 1: - # Shouldn't get this response if we aren't doing a getMore - if cursor_id is None: - raise ProtocolError("No cursor id for getMore operation") - - # Fake a getMore command response. OP_GET_MORE provides no - # document. - msg = "Cursor not found, cursor id: %d" % (cursor_id,) - errobj = {"ok": 0, "errmsg": msg, "code": 43} - raise CursorNotFound(msg, 43, errobj) - elif self.flags & 2: - error_object: dict = bson.BSON(self.documents).decode() - # Fake the ok field if it doesn't exist. - error_object.setdefault("ok", 0) - if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): - raise NotPrimaryError(error_object["$err"], error_object) - elif error_object.get("code") == 50: - default_msg = "operation exceeded time limit" - raise ExecutionTimeout( - error_object.get("$err", default_msg), error_object.get("code"), error_object - ) - raise OperationFailure( - "database error: %s" % error_object.get("$err"), - error_object.get("code"), - error_object, - ) - if self.documents: - return [self.documents] - return [] - - def unpack_response( - self, - cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, - user_fields: Optional[Mapping[str, Any]] = None, - legacy_response: bool = False, - ) -> list[dict[str, Any]]: - """Unpack a response from the database and decode the BSON document(s). - - Check the response for errors and unpack, returning a dictionary - containing the response data. - - Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or - OperationFailure. - - :param cursor_id: cursor_id we sent to get this response - - used for raising an informative exception when we get cursor id not - valid at server response - :param codec_options: an instance of - :class:`~bson.codec_options.CodecOptions` - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. 
- """ - self.raw_response(cursor_id) - if legacy_response: - return bson.decode_all(self.documents, codec_options) - return bson._decode_all_selective(self.documents, codec_options, user_fields) - - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: - """Unpack a command response.""" - docs = self.unpack_response(codec_options=codec_options) - assert self.number_returned == 1 - return docs[0] - - def raw_command_response(self) -> NoReturn: - """Return the bytes of the command response.""" - # This should never be called on _OpReply. - raise NotImplementedError - - @property - def more_to_come(self) -> bool: - """Is the moreToCome bit set on this response?""" - return False - - @classmethod - def unpack(cls, msg: bytes) -> _OpReply: - """Construct an _OpReply from raw bytes.""" - # PYTHON-945: ignore starting_from field. - flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) - - documents = msg[20:] - return cls(flags, cursor_id, number_returned, documents) - - -class _OpMsg: - """A MongoDB OP_MSG response message.""" - - __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") - - UNPACK_FROM = struct.Struct(" list[Mapping[str, Any]]: - """ - cursor_id is ignored - user_fields is used to determine which fields must not be decoded - """ - inflated_response = _decode_selective( - RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS - ) - return [inflated_response] - - def unpack_response( - self, - cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, - user_fields: Optional[Mapping[str, Any]] = None, - legacy_response: bool = False, - ) -> list[dict[str, Any]]: - """Unpack a OP_MSG command response. - - :param cursor_id: Ignored, for compatibility with _OpReply. 
- :param codec_options: an instance of - :class:`~bson.codec_options.CodecOptions` - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. - """ - # If _OpMsg is in-use, this cannot be a legacy response. - assert not legacy_response - return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: - """Unpack a command response.""" - return self.unpack_response(codec_options=codec_options)[0] - - def raw_command_response(self) -> bytes: - """Return the bytes of the command response.""" - return self.payload_document - - @property - def more_to_come(self) -> bool: - """Is the moreToCome bit set on this response?""" - return bool(self.flags & self.MORE_TO_COME) - - @classmethod - def unpack(cls, msg: bytes) -> _OpMsg: - """Construct an _OpMsg from raw bytes.""" - flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) - if flags != 0: - if flags & cls.CHECKSUM_PRESENT: - raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") - - if flags ^ cls.MORE_TO_COME: - raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") - if first_payload_type != 0: - raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") - - if len(msg) != first_payload_size + 5: - raise ProtocolError("Unsupported OP_MSG reply: >1 section") - - payload_document = msg[5:] - return cls(flags, payload_document) - - -_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { - _OpReply.OP_CODE: _OpReply.unpack, - _OpMsg.OP_CODE: _OpMsg.unpack, -} diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index c743c569d..27a420bb1 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -58,42 +58,14 @@ from typing import ( from bson.codec_options import 
DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, helpers_constants -from pymongo.asynchronous import ( - client_session, - common, - database, - helpers, - message, - periodic_executor, - uri_parser, -) +from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo.asynchronous import client_session, database, periodic_executor from pymongo.asynchronous.change_stream import ChangeStream, ClusterChangeStream -from pymongo.asynchronous.client_options import ClientOptions from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.logger import _CLIENT_LOGGER, _log_or_warn -from pymongo.asynchronous.monitoring import ConnectionClosedReason -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.asynchronous.server_selectors import writable_server_selector from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext -from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, TopologyDescription -from pymongo.asynchronous.typings import ( - ClusterTime, - _Address, - _CollationIn, - _DocumentType, - _DocumentTypeArg, - _Pipeline, -) -from pymongo.asynchronous.uri_parser import ( - _check_options, - _handle_option_deprecations, - _handle_security_options, - _normalize_options, -) +from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -108,29 +80,52 @@ from pymongo.errors import ( WriteConcernError, ) from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks +from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.message import _CursorAddress, _GetMore, _Query +from pymongo.monitoring import ConnectionClosedReason +from pymongo.operations 
import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription +from pymongo.typings import ( + ClusterTime, + _Address, + _CollationIn, + _DocumentType, + _DocumentTypeArg, + _Pipeline, +) +from pymongo.uri_parser import ( + _check_options, + _handle_option_deprecations, + _handle_security_options, + _normalize_options, +) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern if TYPE_CHECKING: from types import TracebackType from bson.objectid import ObjectId - from pymongo.asynchronous.bulk import _Bulk - from pymongo.asynchronous.client_session import ClientSession, _ServerSession + from pymongo.asynchronous.bulk import _AsyncBulk + from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession from pymongo.asynchronous.cursor import _ConnectionManager - from pymongo.asynchronous.message import _CursorAddress, _GetMore, _Query - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.response import Response + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server - from pymongo.asynchronous.server_selectors import Selection from pymongo.read_concern import ReadConcern + from pymongo.response import Response + from pymongo.server_selectors import Selection T = TypeVar("T") -_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], Coroutine[Any, Any, T]] +_WriteCall = Callable[ + [Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T] +] _ReadCall = Callable[ - [Optional["ClientSession"], "Server", "Connection", _ServerMode], Coroutine[Any, Any, T] + [Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode], + Coroutine[Any, Any, T], ] _IS_SYNC = False @@ -892,7 +887,7 @@ class 
AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self._kill_cursors_executor = executor self._opened = False - def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]: + def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[bool]: return self._options.load_balanced and not (session and session.in_transaction) def _after_fork(self) -> None: @@ -915,7 +910,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): batch_size: Optional[int] = None, collation: Optional[_CollationIn] = None, start_at_operation_time: Optional[Timestamp] = None, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, start_after: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -989,7 +984,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param start_after: The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. 
@@ -1163,30 +1158,30 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): """Request that a cursor and/or connection be cleaned up soon.""" self._kill_cursors_queue.append((address, cursor_id, conn_mgr)) - def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: + def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) - return client_session.ClientSession(self, server_session, opts, implicit) + return client_session.AsyncClientSession(self, server_session, opts, implicit) def start_session( self, causal_consistency: Optional[bool] = None, default_transaction_options: Optional[client_session.TransactionOptions] = None, snapshot: Optional[bool] = False, - ) -> client_session.ClientSession: + ) -> client_session.AsyncClientSession: """Start a logical session. This method takes the same parameters as :class:`~pymongo.client_session.SessionOptions`. See the :mod:`~pymongo.client_session` module for details and examples. - A :class:`~pymongo.client_session.ClientSession` may only be used with - the MongoClient that started it. :class:`ClientSession` instances are + A :class:`~pymongo.client_session.AsyncClientSession` may only be used with + the MongoClient that started it. :class:`AsyncClientSession` instances are **not thread-safe or fork-safe**. They can only be used by one thread - or process at a time. A single :class:`ClientSession` cannot be used + or process at a time. A single :class:`AsyncClientSession` cannot be used to run multiple operations concurrently. - :return: An instance of :class:`~pymongo.client_session.ClientSession`. + :return: An instance of :class:`~pymongo.client_session.AsyncClientSession`. .. 
versionadded:: 3.6 """ @@ -1197,7 +1192,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): snapshot=snapshot, ) - def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: + def _ensure_session( + self, session: Optional[AsyncClientSession] = None + ) -> Optional[AsyncClientSession]: """If provided session is None, lend a temporary session.""" if session: return session @@ -1211,7 +1208,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): return None def _send_cluster_time( - self, command: MutableMapping[str, Any], session: Optional[ClientSession] + self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession] ) -> None: topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None @@ -1474,7 +1471,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def _end_sessions(self, session_ids: list[_ServerSession]) -> None: """Send endSessions command(s) with the given session ids.""" try: - # Use Connection.command directly to avoid implicitly creating + # Use AsyncConnection.command directly to avoid implicitly creating # another session. async with await self._conn_for_reads( ReadPreference.PRIMARY_PREFERRED, None, operation=_Op.END_SESSIONS @@ -1535,8 +1532,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): @contextlib.asynccontextmanager async def _checkout( - self, server: Server, session: Optional[ClientSession] - ) -> AsyncGenerator[Connection, None]: + self, server: Server, session: Optional[AsyncClientSession] + ) -> AsyncGenerator[AsyncConnection, None]: in_txn = session and session.in_transaction async with _MongoClientErrorHandler(self, server, session) as err_handler: # Reuse the pinned connection, if it exists. 
@@ -1570,7 +1567,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def _select_server( self, server_selector: Callable[[Selection], Selection], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, address: Optional[_Address] = None, deprioritized_servers: Optional[list[Server]] = None, @@ -1581,7 +1578,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): :Parameters: - `server_selector`: The server selector to use if the session is not pinned and no address is given. - - `session`: The ClientSession for the next operation, or None. May + - `session`: The AsyncClientSession for the next operation, or None. May be pinned to a mongos server address. - `address` (optional): Address when sending a message to a specific server, used for getMore. @@ -1615,15 +1612,15 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): raise async def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> AsyncContextManager[Connection]: + self, session: Optional[AsyncClientSession], operation: str + ) -> AsyncContextManager[AsyncConnection]: server = await self._select_server(writable_server_selector, session, operation) return self._checkout(server, session) @contextlib.asynccontextmanager async def _conn_from_server( - self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] - ) -> AsyncGenerator[tuple[Connection, _ServerMode], None]: + self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession] + ) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]: assert read_preference is not None, "read_preference must not be None" # Get a connection for a server matching the read preference, and yield # conn with the effective read preference. 
The Server Selection @@ -1648,9 +1645,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def _conn_for_reads( self, read_preference: _ServerMode, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, - ) -> AsyncContextManager[tuple[Connection, _ServerMode]]: + ) -> AsyncContextManager[tuple[AsyncConnection, _ServerMode]]: assert read_preference is not None, "read_preference must not be None" server = await self._select_server(read_preference, session, operation) return self._conn_from_server(read_preference, server, session) @@ -1672,13 +1669,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): if operation.conn_mgr: server = await self._select_server( operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] operation.name, address=address, ) async with operation.conn_mgr._alock: - async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: + async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return await server.run_operation( operation.conn_mgr.conn, @@ -1690,9 +1687,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): ) async def _cmd( - _session: Optional[ClientSession], + _session: Optional[AsyncClientSession], server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> Response: operation.reset() # Reset op in case of retry. 
@@ -1708,9 +1705,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): return await self._retryable_read( _cmd, operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] address=address, - retryable=isinstance(operation, message._Query), + retryable=isinstance(operation, _Query), operation=operation.name, ) @@ -1718,8 +1715,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self, retryable: bool, func: _WriteCall[T], - session: Optional[ClientSession], - bulk: Optional[_Bulk], + session: Optional[AsyncClientSession], + bulk: Optional[_AsyncBulk], operation: str, operation_id: Optional[int] = None, ) -> T: @@ -1748,8 +1745,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def _retry_internal( self, func: _WriteCall[T] | _ReadCall[T], - session: Optional[ClientSession], - bulk: Optional[_Bulk], + session: Optional[AsyncClientSession], + bulk: Optional[_AsyncBulk], operation: str, is_read: bool = False, address: Optional[_Address] = None, @@ -1787,7 +1784,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self, func: _ReadCall[T], read_pref: _ServerMode, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, address: Optional[_Address] = None, retryable: bool = True, @@ -1830,9 +1827,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self, retryable: bool, func: _WriteCall[T], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, - bulk: Optional[_Bulk] = None, + bulk: Optional[_AsyncBulk] = None, operation_id: Optional[int] = None, ) -> T: """Execute an operation with consecutive retries if possible @@ -1856,7 +1853,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): cursor_id: int, address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], 
explicit_session: bool, ) -> None: """Cleanup a cursor from __del__ without locking. @@ -1880,7 +1877,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): cursor_id: int, address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], explicit_session: bool, ) -> None: """Cleanup a cursor from cursor.close() using a lock. @@ -1913,7 +1910,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self, cursor_id: int, address: Optional[_CursorAddress], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, conn_mgr: Optional[_ConnectionManager] = None, ) -> None: """Send a kill cursors message with the given id. @@ -1941,7 +1938,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): cursor_ids: Sequence[int], address: Optional[_CursorAddress], topology: Topology, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], ) -> None: """Send a kill cursors message with the given ids.""" if address: @@ -1960,8 +1957,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self, cursor_ids: Sequence[int], address: _CursorAddress, - session: Optional[ClientSession], - conn: Connection, + session: Optional[AsyncClientSession], + conn: AsyncConnection, ) -> None: namespace = address.namespace db, coll = namespace.split(".", 1) @@ -1994,7 +1991,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # can be caught in _process_periodic_tasks raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # Don't re-open topology if it's closed and there's no pending cursors. 
if address_to_cursor_ids: @@ -2006,7 +2003,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # This method is run periodically by a background thread. async def _process_periodic_tasks(self) -> None: @@ -2020,7 +2017,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): if isinstance(exc, InvalidOperation) and self._topology._closed: return else: - helpers._handle_exception() + helpers_shared._handle_exception() def _return_server_session( self, server_session: Union[_ServerSession, _EmptyServerSession] @@ -2032,12 +2029,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): @contextlib.asynccontextmanager async def _tmp_session( - self, session: Optional[client_session.ClientSession], close: bool = True - ) -> AsyncGenerator[Optional[client_session.ClientSession], None, None]: + self, session: Optional[client_session.AsyncClientSession], close: bool = True + ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]: """If provided session is None, lend a temporary session.""" if session is not None: - if not isinstance(session, client_session.ClientSession): - raise ValueError("'session' argument must be a ClientSession or None.") + if not isinstance(session, client_session.AsyncClientSession): + raise ValueError("'session' argument must be an AsyncClientSession or None.") # Don't call end_session. 
yield session return @@ -2061,19 +2058,19 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): yield None async def _process_response( - self, reply: Mapping[str, Any], session: Optional[ClientSession] + self, reply: Mapping[str, Any], session: Optional[AsyncClientSession] ) -> None: await self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) async def server_info( - self, session: Optional[client_session.ClientSession] = None + self, session: Optional[client_session.AsyncClientSession] = None ) -> dict[str, Any]: """Get information about the MongoDB server we're connected to. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. @@ -2087,7 +2084,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def _list_databases( self, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncCommandCursor[dict[str, Any]]: @@ -2109,14 +2106,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def list_databases( self, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncCommandCursor[dict[str, Any]]: """Get a cursor over the databases of the connected server. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: Optional parameters of the @@ -2134,13 +2131,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def list_database_names( self, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, ) -> list[str]: """Get a list of the names of all databases on the connected server. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2156,7 +2153,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): async def drop_database( self, name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]], - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, ) -> None: """Drop a database. @@ -2168,7 +2165,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): :class:`~pymongo.database.Database` instance representing the database to drop :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2236,10 +2233,10 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong # Do not consult writeConcernError for pre-4.4 mongos. if isinstance(exc, WriteConcernError) and is_mongos: pass - elif code in helpers_constants._RETRYABLE_ERROR_CODES: + elif code in helpers_shared._RETRYABLE_ERROR_CODES: exc._add_error_label("RetryableWriteError") - # Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is + # AsyncConnection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is # handled above. 
if isinstance(exc, ConnectionFailure) and not isinstance( exc, (NotPrimaryError, WaitQueueTimeoutError) @@ -2261,7 +2258,9 @@ class _MongoClientErrorHandler: "handled", ) - def __init__(self, client: AsyncMongoClient, server: Server, session: Optional[ClientSession]): + def __init__( + self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] + ): self.client = client self.server_address = server.description.address self.session = session @@ -2275,7 +2274,7 @@ class _MongoClientErrorHandler: self.service_id: Optional[ObjectId] = None self.handled = False - def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None: + def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None: """Provide socket information to the error handler.""" self.max_wire_version = conn.max_wire_version self.sock_generation = conn.generation @@ -2327,10 +2326,10 @@ class _ClientConnectionRetryable(Generic[T]): self, mongo_client: AsyncMongoClient, func: _WriteCall[T] | _ReadCall[T], - bulk: Optional[_Bulk], + bulk: Optional[_AsyncBulk], operation: str, is_read: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, read_pref: Optional[_ServerMode] = None, address: Optional[_Address] = None, retryable: bool = False, @@ -2392,7 +2391,7 @@ class _ClientConnectionRetryable(Generic[T]): exc_code = getattr(exc, "code", None) if self._is_not_eligible_for_retry() or ( isinstance(exc, OperationFailure) - and exc_code not in helpers_constants._RETRYABLE_ERROR_CODES + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES ): raise self._retrying = True diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 6bd806108..a5f743512 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -21,19 +21,20 @@ import time import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from pymongo 
import common from pymongo._csot import MovingMinimum -from pymongo.asynchronous import common, periodic_executor -from pymongo.asynchronous.hello import Hello +from pymongo.asynchronous import periodic_executor from pymongo.asynchronous.periodic_executor import _shutdown_executors -from pymongo.asynchronous.pool import _is_faas -from pymongo.asynchronous.read_preferences import MovingAverage -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.srv_resolver import _SrvResolver from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.hello import Hello from pymongo.lock import _create_lock +from pymongo.pool_options import _is_faas +from pymongo.read_preferences import MovingAverage +from pymongo.server_description import ServerDescription +from pymongo.srv_resolver import _SrvResolver if TYPE_CHECKING: - from pymongo.asynchronous.pool import Connection, Pool, _CancellationContext + from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology @@ -294,7 +295,7 @@ class Monitor(MonitorBase): ) return sd - async def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: + async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]: """Return (Hello, round_trip_time). Can raise ConnectionFailure or OperationFailure. diff --git a/pymongo/asynchronous/monitoring.py b/pymongo/asynchronous/monitoring.py deleted file mode 100644 index 36d015fe2..000000000 --- a/pymongo/asynchronous/monitoring.py +++ /dev/null @@ -1,1903 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. 
You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Tools to monitor driver events. - -.. versionadded:: 3.1 - -.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below - are included in the PyMongo distribution under the - :mod:`~pymongo.event_loggers` submodule. - -Use :func:`register` to register global listeners for specific events. -Listeners must inherit from one of the abstract classes below and implement -the correct functions for that class. - -For example, a simple command logger might be implemented like this:: - - import logging - - from pymongo import monitoring - - class CommandLogger(monitoring.CommandListener): - - def started(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} started on server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "succeeded in {0.duration_micros} " - "microseconds".format(event)) - - def failed(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "failed in {0.duration_micros} " - "microseconds".format(event)) - - monitoring.register(CommandLogger()) - -Server discovery and monitoring events are also available. 
For example:: - - class ServerLogger(monitoring.ServerListener): - - def opened(self, event): - logging.info("Server {0.server_address} added to topology " - "{0.topology_id}".format(event)) - - def description_changed(self, event): - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - "Server {0.server_address} changed type from " - "{0.previous_description.server_type_name} to " - "{0.new_description.server_type_name}".format(event)) - - def closed(self, event): - logging.warning("Server {0.server_address} removed from topology " - "{0.topology_id}".format(event)) - - - class HeartbeatLogger(monitoring.ServerHeartbeatListener): - - def started(self, event): - logging.info("Heartbeat sent to server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - # The reply.document attribute was added in PyMongo 3.4. 
- logging.info("Heartbeat to server {0.connection_id} " - "succeeded with reply " - "{0.reply.document}".format(event)) - - def failed(self, event): - logging.warning("Heartbeat to server {0.connection_id} " - "failed with error {0.reply}".format(event)) - - class TopologyLogger(monitoring.TopologyListener): - - def opened(self, event): - logging.info("Topology with id {0.topology_id} " - "opened".format(event)) - - def description_changed(self, event): - logging.info("Topology description updated for " - "topology id {0.topology_id}".format(event)) - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - "Topology {0.topology_id} changed type from " - "{0.previous_description.topology_type_name} to " - "{0.new_description.topology_type_name}".format(event)) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event): - logging.info("Topology with id {0.topology_id} " - "closed".format(event)) - -Connection monitoring and pooling events are also available. 
For example:: - - class ConnectionPoolLogger(ConnectionPoolListener): - - def pool_created(self, event): - logging.info("[pool {0.address}] pool created".format(event)) - - def pool_ready(self, event): - logging.info("[pool {0.address}] pool is ready".format(event)) - - def pool_cleared(self, event): - logging.info("[pool {0.address}] pool cleared".format(event)) - - def pool_closed(self, event): - logging.info("[pool {0.address}] pool closed".format(event)) - - def connection_created(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection created".format(event)) - - def connection_ready(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection setup succeeded".format(event)) - - def connection_closed(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection closed, reason: " - "{0.reason}".format(event)) - - def connection_check_out_started(self, event): - logging.info("[pool {0.address}] connection check out " - "started".format(event)) - - def connection_check_out_failed(self, event): - logging.info("[pool {0.address}] connection check out " - "failed, reason: {0.reason}".format(event)) - - def connection_checked_out(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked out of pool".format(event)) - - def connection_checked_in(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked into pool".format(event)) - - -Event listeners can also be registered per instance of -:class:`~pymongo.mongo_client.MongoClient`:: - - client = MongoClient(event_listeners=[CommandLogger()]) - -Note that previously registered global listeners are automatically included -when configuring per client event listeners. Registering a new global listener -will not add that listener to existing client instances. - -.. note:: Events are delivered **synchronously**. 
Application threads block - waiting for event handlers (e.g. :meth:`~CommandListener.started`) to - return. Care must be taken to ensure that your event handlers are efficient - enough to not adversely affect overall application performance. - -.. warning:: The command documents published through this API are *not* copies. - If you intend to modify them in any way you must copy them in your event - handler first. -""" - -from __future__ import annotations - -import datetime -from collections import abc, namedtuple -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from bson.objectid import ObjectId -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.helpers import _handle_exception -from pymongo.asynchronous.typings import _Address, _DocumentOut -from pymongo.helpers_constants import _SENSITIVE_COMMANDS - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.asynchronous.server_description import ServerDescription - from pymongo.asynchronous.topology_description import TopologyDescription - -_IS_SYNC = False - -_Listeners = namedtuple( - "_Listeners", - ( - "command_listeners", - "server_listeners", - "server_heartbeat_listeners", - "topology_listeners", - "cmap_listeners", - ), -) - -_LISTENERS = _Listeners([], [], [], [], []) - - -class _EventListener: - """Abstract base class for all event listeners.""" - - -class CommandListener(_EventListener): - """Abstract base class for command listeners. - - Handles `CommandStartedEvent`, `CommandSucceededEvent`, - and `CommandFailedEvent`. - """ - - def started(self, event: CommandStartedEvent) -> None: - """Abstract method to handle a `CommandStartedEvent`. - - :param event: An instance of :class:`CommandStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: CommandSucceededEvent) -> None: - """Abstract method to handle a `CommandSucceededEvent`. 
- - :param event: An instance of :class:`CommandSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: CommandFailedEvent) -> None: - """Abstract method to handle a `CommandFailedEvent`. - - :param event: An instance of :class:`CommandFailedEvent`. - """ - raise NotImplementedError - - -class ConnectionPoolListener(_EventListener): - """Abstract base class for connection pool listeners. - - Handles all of the connection pool events defined in the Connection - Monitoring and Pooling Specification: - :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, - :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, - :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, - :class:`ConnectionCheckOutStartedEvent`, - :class:`ConnectionCheckOutFailedEvent`, - :class:`ConnectionCheckedOutEvent`, - and :class:`ConnectionCheckedInEvent`. - - .. versionadded:: 3.9 - """ - - def pool_created(self, event: PoolCreatedEvent) -> None: - """Abstract method to handle a :class:`PoolCreatedEvent`. - - Emitted when a connection Pool is created. - - :param event: An instance of :class:`PoolCreatedEvent`. - """ - raise NotImplementedError - - def pool_ready(self, event: PoolReadyEvent) -> None: - """Abstract method to handle a :class:`PoolReadyEvent`. - - Emitted when a connection Pool is marked ready. - - :param event: An instance of :class:`PoolReadyEvent`. - - .. versionadded:: 4.0 - """ - raise NotImplementedError - - def pool_cleared(self, event: PoolClearedEvent) -> None: - """Abstract method to handle a `PoolClearedEvent`. - - Emitted when a connection Pool is cleared. - - :param event: An instance of :class:`PoolClearedEvent`. - """ - raise NotImplementedError - - def pool_closed(self, event: PoolClosedEvent) -> None: - """Abstract method to handle a `PoolClosedEvent`. - - Emitted when a connection Pool is closed. - - :param event: An instance of :class:`PoolClosedEvent`. 
- """ - raise NotImplementedError - - def connection_created(self, event: ConnectionCreatedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCreatedEvent`. - - Emitted when a connection Pool creates a Connection object. - - :param event: An instance of :class:`ConnectionCreatedEvent`. - """ - raise NotImplementedError - - def connection_ready(self, event: ConnectionReadyEvent) -> None: - """Abstract method to handle a :class:`ConnectionReadyEvent`. - - Emitted when a connection has finished its setup, and is now ready to - use. - - :param event: An instance of :class:`ConnectionReadyEvent`. - """ - raise NotImplementedError - - def connection_closed(self, event: ConnectionClosedEvent) -> None: - """Abstract method to handle a :class:`ConnectionClosedEvent`. - - Emitted when a connection Pool closes a connection. - - :param event: An instance of :class:`ConnectionClosedEvent`. - """ - raise NotImplementedError - - def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. - - Emitted when the driver starts attempting to check out a connection. - - :param event: An instance of :class:`ConnectionCheckOutStartedEvent`. - """ - raise NotImplementedError - - def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. - - Emitted when the driver's attempt to check out a connection fails. - - :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. - """ - raise NotImplementedError - - def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. - - Emitted when the driver successfully checks out a connection. - - :param event: An instance of :class:`ConnectionCheckedOutEvent`. 
- """ - raise NotImplementedError - - def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedInEvent`. - - Emitted when the driver checks in a connection back to the connection - Pool. - - :param event: An instance of :class:`ConnectionCheckedInEvent`. - """ - raise NotImplementedError - - -class ServerHeartbeatListener(_EventListener): - """Abstract base class for server heartbeat listeners. - - Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, - and `ServerHeartbeatFailedEvent`. - - .. versionadded:: 3.3 - """ - - def started(self, event: ServerHeartbeatStartedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatStartedEvent`. - - :param event: An instance of :class:`ServerHeartbeatStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: - """Abstract method to handle a `ServerHeartbeatSucceededEvent`. - - :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: ServerHeartbeatFailedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatFailedEvent`. - - :param event: An instance of :class:`ServerHeartbeatFailedEvent`. - """ - raise NotImplementedError - - -class TopologyListener(_EventListener): - """Abstract base class for topology monitoring listeners. - Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and - `TopologyClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: TopologyOpenedEvent) -> None: - """Abstract method to handle a `TopologyOpenedEvent`. - - :param event: An instance of :class:`TopologyOpenedEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: - """Abstract method to handle a `TopologyDescriptionChangedEvent`. - - :param event: An instance of :class:`TopologyDescriptionChangedEvent`. 
- """ - raise NotImplementedError - - def closed(self, event: TopologyClosedEvent) -> None: - """Abstract method to handle a `TopologyClosedEvent`. - - :param event: An instance of :class:`TopologyClosedEvent`. - """ - raise NotImplementedError - - -class ServerListener(_EventListener): - """Abstract base class for server listeners. - Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and - `ServerClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: ServerOpeningEvent) -> None: - """Abstract method to handle a `ServerOpeningEvent`. - - :param event: An instance of :class:`ServerOpeningEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: ServerDescriptionChangedEvent) -> None: - """Abstract method to handle a `ServerDescriptionChangedEvent`. - - :param event: An instance of :class:`ServerDescriptionChangedEvent`. - """ - raise NotImplementedError - - def closed(self, event: ServerClosedEvent) -> None: - """Abstract method to handle a `ServerClosedEvent`. - - :param event: An instance of :class:`ServerClosedEvent`. - """ - raise NotImplementedError - - -def _to_micros(dur: timedelta) -> int: - """Convert duration 'dur' to microseconds.""" - return int(dur.total_seconds() * 10e5) - - -def _validate_event_listeners( - option: str, listeners: Sequence[_EventListeners] -) -> Sequence[_EventListeners]: - """Validate event listeners""" - if not isinstance(listeners, abc.Sequence): - raise TypeError(f"{option} must be a list or tuple") - for listener in listeners: - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {option} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - return listeners - - -def register(listener: _EventListener) -> None: - """Register a global event listener. 
- - :param listener: A subclasses of :class:`CommandListener`, - :class:`ServerHeartbeatListener`, :class:`ServerListener`, - :class:`TopologyListener`, or :class:`ConnectionPoolListener`. - """ - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {listener} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - if isinstance(listener, CommandListener): - _LISTENERS.command_listeners.append(listener) - if isinstance(listener, ServerHeartbeatListener): - _LISTENERS.server_heartbeat_listeners.append(listener) - if isinstance(listener, ServerListener): - _LISTENERS.server_listeners.append(listener) - if isinstance(listener, TopologyListener): - _LISTENERS.topology_listeners.append(listener) - if isinstance(listener, ConnectionPoolListener): - _LISTENERS.cmap_listeners.append(listener) - - -# The "hello" command is also deemed sensitive when attempting speculative -# authentication. 
-def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: - if ( - command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) - and "speculativeAuthenticate" in doc - ): - return True - return False - - -class _CommandEvent: - """Base class for command events.""" - - __slots__ = ( - "__cmd_name", - "__rqst_id", - "__conn_id", - "__op_id", - "__service_id", - "__db", - "__server_conn_id", - ) - - def __init__( - self, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - self.__cmd_name = command_name - self.__rqst_id = request_id - self.__conn_id = connection_id - self.__op_id = operation_id - self.__service_id = service_id - self.__db = database_name - self.__server_conn_id = server_connection_id - - @property - def command_name(self) -> str: - """The command name.""" - return self.__cmd_name - - @property - def request_id(self) -> int: - """The request id for this operation.""" - return self.__rqst_id - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this command was sent to.""" - return self.__conn_id - - @property - def service_id(self) -> Optional[ObjectId]: - """The service_id this command was sent to, or ``None``. - - .. versionadded:: 3.12 - """ - return self.__service_id - - @property - def operation_id(self) -> Optional[int]: - """An id for this series of events or None.""" - return self.__op_id - - @property - def database_name(self) -> str: - """The database_name this command was sent to, or ``""``. - - .. versionadded:: 4.6 - """ - return self.__db - - @property - def server_connection_id(self) -> Optional[int]: - """The server-side connection id for the connection this command was sent on, or ``None``. - - .. 
versionadded:: 4.7 - """ - return self.__server_conn_id - - -class CommandStartedEvent(_CommandEvent): - """Event published when a command starts. - - :param command: The command document. - :param database_name: The name of the database this command was run against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - """ - - __slots__ = ("__cmd",) - - def __init__( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - server_connection_id: Optional[int] = None, - ) -> None: - if not command: - raise ValueError(f"{command!r} is not a valid command") - # Command name must be first key. - command_name = next(iter(command)) - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command): - self.__cmd: _DocumentOut = {} - else: - self.__cmd = command - - @property - def command(self) -> _DocumentOut: - """The command document.""" - return self.__cmd - - @property - def database_name(self) -> str: - """The name of the database this command was run against.""" - return super().database_name - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.service_id, - self.server_connection_id, - ) - - -class CommandSucceededEvent(_CommandEvent): - """Event 
published when a command succeeds. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - - __slots__ = ("__duration_micros", "__reply") - - def __init__( - self, - duration: datetime.timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply): - self.__reply: _DocumentOut = {} - else: - self.__reply = reply - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def reply(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__reply - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.service_id, - self.server_connection_id, - ) - - -class 
CommandFailedEvent(_CommandEvent): - """Event published when a command fails. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - - __slots__ = ("__duration_micros", "__failure") - - def __init__( - self, - duration: datetime.timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - self.__failure = failure - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def failure(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__failure - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, " - "failure: {!r}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.failure, - self.service_id, - self.server_connection_id, - ) - - -class _PoolEvent: - """Base class for pool events.""" - - __slots__ = ("__address",) - - def 
__init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server the pool is attempting - to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class PoolCreatedEvent(_PoolEvent): - """Published when a Connection Pool is created. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__options",) - - def __init__(self, address: _Address, options: dict[str, Any]) -> None: - super().__init__(address) - self.__options = options - - @property - def options(self) -> dict[str, Any]: - """Any non-default pool options that were set on this Connection Pool.""" - return self.__options - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" - - -class PoolReadyEvent(_PoolEvent): - """Published when a Connection Pool is marked ready. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 4.0 - """ - - __slots__ = () - - -class PoolClearedEvent(_PoolEvent): - """Published when a Connection Pool is cleared. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - :param service_id: The service_id this command was sent to, or ``None``. - :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing. - - .. 
versionadded:: 3.9 - """ - - __slots__ = ("__service_id", "__interrupt_connections") - - def __init__( - self, - address: _Address, - service_id: Optional[ObjectId] = None, - interrupt_connections: bool = False, - ) -> None: - super().__init__(address) - self.__service_id = service_id - self.__interrupt_connections = interrupt_connections - - @property - def service_id(self) -> Optional[ObjectId]: - """Connections with this service_id are cleared. - - When service_id is ``None``, all connections in the pool are cleared. - - .. versionadded:: 3.12 - """ - return self.__service_id - - @property - def interrupt_connections(self) -> bool: - """If True, active connections are interrupted during clearing. - - .. versionadded:: 4.7 - """ - return self.__interrupt_connections - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" - - -class PoolClosedEvent(_PoolEvent): - """Published when a Connection Pool is closed. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionClosedEvent`. - - .. versionadded:: 3.9 - """ - - STALE = "stale" - """The pool was cleared, making the connection no longer valid.""" - - IDLE = "idle" - """The connection became stale by being idle for too long (maxIdleTimeMS). - """ - - ERROR = "error" - """The connection experienced an error, making it no longer valid.""" - - POOL_CLOSED = "poolClosed" - """The pool was closed, making the connection no longer valid.""" - - -class ConnectionCheckOutFailedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionCheckOutFailedEvent`. - - .. 
versionadded:: 3.9 - """ - - TIMEOUT = "timeout" - """The connection check out attempt exceeded the specified timeout.""" - - POOL_CLOSED = "poolClosed" - """The pool was previously closed, and cannot provide new connections.""" - - CONN_ERROR = "connectionError" - """The connection check out attempt experienced an error while setting up - a new connection. - """ - - -class _ConnectionEvent: - """Private base class for connection events.""" - - __slots__ = ("__address",) - - def __init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server this connection is - attempting to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class _ConnectionIdEvent(_ConnectionEvent): - """Private base class for connection events with an id.""" - - __slots__ = ("__connection_id",) - - def __init__(self, address: _Address, connection_id: int) -> None: - super().__init__(address) - self.__connection_id = connection_id - - @property - def connection_id(self) -> int: - """The ID of the connection.""" - return self.__connection_id - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" - - -class _ConnectionDurationEvent(_ConnectionIdEvent): - """Private base class for connection events with a duration.""" - - __slots__ = ("__duration",) - - def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: - super().__init__(address, connection_id) - self.__duration = duration - - @property - def duration(self) -> Optional[float]: - """The duration of the connection event. - - .. 
versionadded:: 4.7 - """ - return self.__duration - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" - - -class ConnectionCreatedEvent(_ConnectionIdEvent): - """Published when a Connection Pool creates a Connection object. - - NOTE: This connection is not ready for use until the - :class:`ConnectionReadyEvent` is published. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionReadyEvent(_ConnectionDurationEvent): - """Published when a Connection has finished its setup, and is ready to use. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedEvent(_ConnectionIdEvent): - """Published when a Connection is closed. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - :param reason: A reason explaining why this connection was closed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, connection_id: int, reason: str): - super().__init__(address, connection_id) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why this connection was closed. - - The reason must be one of the strings from the - :class:`ConnectionClosedReason` enum. 
- """ - return self.__reason - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r})".format( - self.__class__.__name__, - self.address, - self.connection_id, - self.__reason, - ) - - -class ConnectionCheckOutStartedEvent(_ConnectionEvent): - """Published when the driver starts attempting to check out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): - """Published when the driver's attempt to check out a connection fails. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param reason: A reason explaining why connection check out failed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: - super().__init__(address=address, connection_id=0, duration=duration) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why connection check out failed. - - The reason must be one of the strings from the - :class:`ConnectionCheckOutFailedReason` enum. - """ - return self.__reason - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" - - -class ConnectionCheckedOutEvent(_ConnectionDurationEvent): - """Published when the driver successfully checks out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckedInEvent(_ConnectionIdEvent): - """Published when the driver checks in a Connection into the Pool. 
- - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class _ServerEvent: - """Base class for server events.""" - - __slots__ = ("__server_address", "__topology_id") - - def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: - self.__server_address = server_address - self.__topology_id = topology_id - - @property - def server_address(self) -> _Address: - """The address (host, port) pair of the server""" - return self.__server_address - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" - - -class ServerDescriptionChangedEvent(_ServerEvent): - """Published when server description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> ServerDescription: - """The previous - :class:`~pymongo.server_description.ServerDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> ServerDescription: - """The new - :class:`~pymongo.server_description.ServerDescription`. 
- """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.server_address, - self.previous_description, - self.new_description, - ) - - -class ServerOpeningEvent(_ServerEvent): - """Published when server is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerClosedEvent(_ServerEvent): - """Published when server is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyEvent: - """Base class for topology description events.""" - - __slots__ = ("__topology_id",) - - def __init__(self, topology_id: ObjectId) -> None: - self.__topology_id = topology_id - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" - - -class TopologyDescriptionChangedEvent(TopologyEvent): - """Published when the topology description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> TopologyDescription: - """The previous - :class:`~pymongo.topology_description.TopologyDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> TopologyDescription: - """The new - :class:`~pymongo.topology_description.TopologyDescription`. 
- """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} topology_id: {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.topology_id, - self.previous_description, - self.new_description, - ) - - -class TopologyOpenedEvent(TopologyEvent): - """Published when the topology is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyClosedEvent(TopologyEvent): - """Published when the topology is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class _ServerHeartbeatEvent: - """Base class for server heartbeat events.""" - - __slots__ = ("__connection_id", "__awaited") - - def __init__(self, connection_id: _Address, awaited: bool = False) -> None: - self.__connection_id = connection_id - self.__awaited = awaited - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this heartbeat was sent - to. - """ - return self.__connection_id - - @property - def awaited(self) -> bool: - """Whether the heartbeat was issued as an awaitable hello command. - - .. versionadded:: 4.6 - """ - return self.__awaited - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" - - -class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): - """Published when a heartbeat is started. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat succeeds. - - .. 
versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Hello: - """An instance of :class:`~pymongo.hello.Hello`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat fails, either with an "ok: 0" - or a socket exception. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Exception: - """A subclass of :exc:`Exception`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. 
versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class _EventListeners: - """Configure event listeners for a client instance. - - Any event listeners registered globally are included by default. - - :param listeners: A list of event listeners. - """ - - def __init__(self, listeners: Optional[Sequence[_EventListener]]): - self.__command_listeners = _LISTENERS.command_listeners[:] - self.__server_listeners = _LISTENERS.server_listeners[:] - lst = _LISTENERS.server_heartbeat_listeners - self.__server_heartbeat_listeners = lst[:] - self.__topology_listeners = _LISTENERS.topology_listeners[:] - self.__cmap_listeners = _LISTENERS.cmap_listeners[:] - if listeners is not None: - for lst in listeners: - if isinstance(lst, CommandListener): - self.__command_listeners.append(lst) - if isinstance(lst, ServerListener): - self.__server_listeners.append(lst) - if isinstance(lst, ServerHeartbeatListener): - self.__server_heartbeat_listeners.append(lst) - if isinstance(lst, TopologyListener): - self.__topology_listeners.append(lst) - if isinstance(lst, ConnectionPoolListener): - self.__cmap_listeners.append(lst) - self.__enabled_for_commands = bool(self.__command_listeners) - self.__enabled_for_server = bool(self.__server_listeners) - self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) - self.__enabled_for_topology = bool(self.__topology_listeners) - self.__enabled_for_cmap = bool(self.__cmap_listeners) - - @property - def enabled_for_commands(self) -> bool: - """Are any CommandListener instances registered?""" - return self.__enabled_for_commands - - @property - def enabled_for_server(self) -> bool: - """Are any ServerListener instances registered?""" - return self.__enabled_for_server - - @property - def enabled_for_server_heartbeat(self) -> bool: - 
"""Are any ServerHeartbeatListener instances registered?""" - return self.__enabled_for_server_heartbeat - - @property - def enabled_for_topology(self) -> bool: - """Are any TopologyListener instances registered?""" - return self.__enabled_for_topology - - @property - def enabled_for_cmap(self) -> bool: - """Are any ConnectionPoolListener instances registered?""" - return self.__enabled_for_cmap - - def event_listeners(self) -> list[_EventListeners]: - """List of registered event listeners.""" - return ( - self.__command_listeners - + self.__server_heartbeat_listeners - + self.__server_listeners - + self.__topology_listeners - + self.__cmap_listeners - ) - - def publish_command_start( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - ) -> None: - """Publish a CommandStartedEvent to all command listeners. - - :param command: The command document. - :param database_name: The name of the database this command was run - against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. 
- """ - if op_id is None: - op_id = request_id - event = CommandStartedEvent( - command, - database_name, - request_id, - connection_id, - op_id, - service_id=service_id, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_command_success( - self, - duration: timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - speculative_hello: bool = False, - database_name: str = "", - ) -> None: - """Publish a CommandSucceededEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param speculative_hello: Was the command sent with speculative auth? - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - if speculative_hello: - # Redact entire response when the command started contained - # speculativeAuthenticate. 
- reply = {} - event = CommandSucceededEvent( - duration, - reply, - command_name, - request_id, - connection_id, - op_id, - service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_command_failure( - self, - duration: timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - database_name: str = "", - ) -> None: - """Publish a CommandFailedEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document or failure description - document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - event = CommandFailedEvent( - duration, - failure, - command_name, - request_id, - connection_id, - op_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: - """Publish a ServerHeartbeatStartedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param awaited: True if this heartbeat is part of an awaitable hello command. 
- """ - event = ServerHeartbeatStartedEvent(connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_succeeded( - self, connection_id: _Address, duration: float, reply: Hello, awaited: bool - ) -> None: - """Publish a ServerHeartbeatSucceededEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_failed( - self, connection_id: _Address, duration: float, reply: Exception, awaited: bool - ) -> None: - """Publish a ServerHeartbeatFailedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerOpeningEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. 
- """ - event = ServerOpeningEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerClosedEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerClosedEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_server_description_changed( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - server_address: _Address, - topology_id: ObjectId, - ) -> None: - """Publish a ServerDescriptionChangedEvent to all server listeners. - - :param previous_description: The previous server description. - :param server_address: The address (host, port) pair of the server. - :param new_description: The new server description. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerDescriptionChangedEvent( - previous_description, new_description, server_address, topology_id - ) - for subscriber in self.__server_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_topology_opened(self, topology_id: ObjectId) -> None: - """Publish a TopologyOpenedEvent to all topology listeners. - - :param topology_id: A unique identifier for the topology this server - is a part of. 
- """ - event = TopologyOpenedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_topology_closed(self, topology_id: ObjectId) -> None: - """Publish a TopologyClosedEvent to all topology listeners. - - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyClosedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_topology_description_changed( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - topology_id: ObjectId, - ) -> None: - """Publish a TopologyDescriptionChangedEvent to all topology listeners. - - :param previous_description: The previous topology description. - :param new_description: The new topology description. - :param topology_id: A unique identifier for the topology this server - is a part of. 
- """ - event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: - """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" - event = PoolCreatedEvent(address, options) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_created(event) - except Exception: - _handle_exception() - - def publish_pool_ready(self, address: _Address) -> None: - """Publish a :class:`PoolReadyEvent` to all pool listeners.""" - event = PoolReadyEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_ready(event) - except Exception: - _handle_exception() - - def publish_pool_cleared( - self, - address: _Address, - service_id: Optional[ObjectId], - interrupt_connections: bool = False, - ) -> None: - """Publish a :class:`PoolClearedEvent` to all pool listeners.""" - event = PoolClearedEvent(address, service_id, interrupt_connections) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_cleared(event) - except Exception: - _handle_exception() - - def publish_pool_closed(self, address: _Address) -> None: - """Publish a :class:`PoolClosedEvent` to all pool listeners.""" - event = PoolClosedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_closed(event) - except Exception: - _handle_exception() - - def publish_connection_created(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCreatedEvent` to all connection - listeners. 
- """ - event = ConnectionCreatedEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_created(event) - except Exception: - _handle_exception() - - def publish_connection_ready( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" - event = ConnectionReadyEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_ready(event) - except Exception: - _handle_exception() - - def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: - """Publish a :class:`ConnectionClosedEvent` to all connection - listeners. - """ - event = ConnectionClosedEvent(address, connection_id, reason) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_closed(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_started(self, address: _Address) -> None: - """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutStartedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_started(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_failed( - self, address: _Address, reason: str, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutFailedEvent(address, reason, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_failed(event) - except Exception: - _handle_exception() - - def publish_connection_checked_out( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckedOutEvent` to all connection - listeners. 
- """ - event = ConnectionCheckedOutEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_out(event) - except Exception: - _handle_exception() - - def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCheckedInEvent` to all connection - listeners. - """ - event = ConnectionCheckedInEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_in(event) - except Exception: - _handle_exception() diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 25fffaca1..ff43a5ffc 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -33,20 +33,18 @@ from typing import ( ) from bson import _decode_all_selective -from pymongo import _csot -from pymongo.asynchronous import helpers as _async_helpers -from pymongo.asynchronous import message as _async_message -from pymongo.asynchronous.common import MAX_MESSAGE_SIZE -from pymongo.asynchronous.compression_support import _NO_COMPRESSION, decompress -from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply -from pymongo.asynchronous.monitoring import _is_speculative_authenticate +from pymongo import _csot, helpers_shared, message +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( NotPrimaryError, OperationFailure, ProtocolError, _OperationCancelled, ) +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply +from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, @@ -58,27 +56,27 @@ from pymongo.socket_checker import 
_errno_from_exception if TYPE_CHECKING: from bson import CodecOptions - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.monitoring import _EventListeners - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = False async def command( - conn: Connection, + conn: AsyncConnection, dbname: str, spec: MutableMapping[str, Any], is_mongos: bool, read_preference: Optional[_ServerMode], codec_options: CodecOptions[_DocumentType], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], client: Optional[AsyncMongoClient], check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, @@ -97,13 +95,13 @@ async def command( ) -> _DocumentType: """Execute a command over the socket, or raise socket.error. - :param conn: a Connection instance + :param conn: a AsyncConnection instance :param dbname: name of the database on which to run the command :param spec: a command document as an ordered dict type, eg SON. :param is_mongos: are we connected to a mongos? :param read_preference: a read preference :param codec_options: a CodecOptions instance - :param session: optional ClientSession instance. 
+ :param session: optional AsyncClientSession instance. :param client: optional AsyncMongoClient instance for updating $clusterTime. :param check: raise OperationFailure if there are errors :param allowable_errors: errors to ignore if `check` is True @@ -130,7 +128,7 @@ async def command( orig = spec if is_mongos and not use_op_msg: assert read_preference is not None - spec = _async_message._maybe_add_read_preference(spec, read_preference) + spec = message._maybe_add_read_preference(spec, read_preference) if read_concern and not (session and session.in_transaction): if read_concern.level: spec["readConcern"] = read_concern.document @@ -158,22 +156,20 @@ async def command( if use_op_msg: flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = _async_message._op_msg( + request_id, msg, size, max_doc_size = message._op_msg( flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx ) # If this is an unacknowledged write then make sure the encoded doc(s) # are small enough, otherwise rely on the server to return an error. 
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - _async_message._raise_document_too_large(name, size, max_bson_size) + message._raise_document_too_large(name, size, max_bson_size) else: - request_id, msg, size = _async_message._query( + request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) - if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD: - _async_message._raise_document_too_large( - name, size, max_bson_size + _async_message._COMMAND_OVERHEAD - ) + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -220,7 +216,7 @@ async def command( if client: await client._process_response(response_doc, session) if check: - _async_helpers._check_command_response( + helpers_shared._check_command_response( response_doc, conn.max_wire_version, allowable_errors, @@ -231,7 +227,7 @@ async def command( if isinstance(exc, (NotPrimaryError, OperationFailure)): failure: _DocumentOut = exc.details # type: ignore[assignment] else: - failure = _async_message._convert_exception(exc) + failure = message._convert_exception(exc) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -310,7 +306,7 @@ async def command( async def receive_message( - conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE + conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" if _csot.get_timeout(): @@ -355,7 +351,7 @@ async def receive_message( return unpack_reply(data) -async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: +async def wait_for_read(conn: AsyncConnection, deadline: 
Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" sock = conn.conn timed_out = False @@ -390,7 +386,7 @@ async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: async def _receive_data_on_socket( - conn: Connection, length: int, deadline: Optional[float] + conn: AsyncConnection, length: int, deadline: Optional[float] ) -> memoryview: buf = bytearray(length) mv = memoryview(buf) diff --git a/pymongo/asynchronous/operations.py b/pymongo/asynchronous/operations.py deleted file mode 100644 index d4beff759..000000000 --- a/pymongo/asynchronous/operations.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Operation class definitions.""" -from __future__ import annotations - -import enum -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) - -from bson.raw_bson import RawBSONDocument -from pymongo.asynchronous import helpers -from pymongo.asynchronous.collation import validate_collation_or_none -from pymongo.asynchronous.common import validate_is_mapping, validate_list -from pymongo.asynchronous.helpers import _gen_index_name, _index_document, _index_list -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline -from pymongo.write_concern import validate_boolean - -if TYPE_CHECKING: - from pymongo.asynchronous.bulk import _Bulk - -_IS_SYNC = False - -# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary -_IndexList = Union[ - Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] -] -_IndexKeyHint = Union[str, _IndexList] - - -class _Op(str, enum.Enum): - ABORT = "abortTransaction" - AGGREGATE = "aggregate" - COMMIT = "commitTransaction" - COUNT = "count" - CREATE = "create" - CREATE_INDEXES = "createIndexes" - CREATE_SEARCH_INDEXES = "createSearchIndexes" - DELETE = "delete" - DISTINCT = "distinct" - DROP = "drop" - DROP_DATABASE = "dropDatabase" - DROP_INDEXES = "dropIndexes" - DROP_SEARCH_INDEXES = "dropSearchIndexes" - END_SESSIONS = "endSessions" - FIND_AND_MODIFY = "findAndModify" - FIND = "find" - INSERT = "insert" - LIST_COLLECTIONS = "listCollections" - LIST_INDEXES = "listIndexes" - LIST_SEARCH_INDEX = "listSearchIndexes" - LIST_DATABASES = "listDatabases" - UPDATE = "update" - UPDATE_INDEX = "updateIndex" - UPDATE_SEARCH_INDEX = "updateSearchIndex" - RENAME = "rename" - GETMORE = "getMore" - KILL_CURSORS = "killCursors" - TEST = "testOperation" - - -class InsertOne(Generic[_DocumentType]): - """Represents an insert_one operation.""" - - __slots__ = ("_doc",) - - def 
__init__(self, document: _DocumentType) -> None: - """Create an InsertOne instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param document: The document to insert. If the document is missing an - _id field one will be added. - """ - self._doc = document - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_insert(self._doc) # type: ignore[arg-type] - - def __repr__(self) -> str: - return f"InsertOne({self._doc!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return other._doc == self._doc - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteOne: - """Represents a delete_one operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteOne instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the document to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. 
- """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 1, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteMany: - """Represents a delete_many operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteMany instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the documents to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. 
- """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 0, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class ReplaceOne(Generic[_DocumentType]): - """Represents a replace_one operation.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - replacement: Union[_DocumentType, RawBSONDocument], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a ReplaceOne instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the document to replace. - :param replacement: The new document. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. 
versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the ``collation`` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._doc = replacement - self._upsert = upsert - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_replace( - self._filter, - self._doc, - self._upsert, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - other._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._hint, - ) - - -class _UpdateOp: - """Private base class for update operations.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - doc: Union[Mapping[str, Any], _Pipeline], - upsert: bool, - collation: Optional[_CollationIn], - array_filters: Optional[list[Mapping[str, Any]]], - hint: Optional[_IndexKeyHint], - ): - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if array_filters is not None: - validate_list("array_filters", array_filters) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, 
dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - - self._filter = filter - self._doc = doc - self._upsert = upsert - self._collation = collation - self._array_filters = array_filters - - def __eq__(self, other: object) -> bool: - if isinstance(other, type(self)): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._array_filters, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - return NotImplemented - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - - -class UpdateOne(_UpdateOp): - """Represents an update_one operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Represents an update_one operation. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the document to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. 
versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - False, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class UpdateMany(_UpdateOp): - """Represents an update_many operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create an UpdateMany instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the documents to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. 
- """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - True, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class IndexModel: - """Represents an index to create.""" - - __slots__ = ("__document",) - - def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: - """Create an Index instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.create_indexes`. - - Takes either a single key or a list containing (key, direction) pairs - or keys. If no direction is given, :data:`~pymongo.ASCENDING` will - be assumed. - The key(s) must be an instance of :class:`str`, and the direction(s) must - be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, - :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, - :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). - - Valid options include, but are not limited to: - - - `name`: custom name to use for this index - if none is - given, a name will be generated. - - `unique`: if ``True``, creates a uniqueness constraint on the index. - - `background`: if ``True``, this index should be created in the - background. - - `sparse`: if ``True``, omit from the index any documents that lack - the indexed field. - - `bucketSize`: for use with geoHaystack indexes. - Number of documents to group together within a certain proximity - to a given longitude and latitude. - - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` - index. - - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` - index. - - `expireAfterSeconds`: Used to create an expiring (TTL) - collection. MongoDB will automatically delete documents from - this collection after seconds. The indexed field must - be a UTC datetime or the data will not expire. 
- - `partialFilterExpression`: A document that specifies a filter for - a partial index. - - `collation`: An instance of :class:`~pymongo.collation.Collation` - that specifies the collation to use. - - `wildcardProjection`: Allows users to include or exclude specific - field paths from a `wildcard index`_ using the { "$**" : 1} key - pattern. Requires MongoDB >= 4.2. - - `hidden`: if ``True``, this index will be hidden from the query - planner and will not be evaluated as part of query plan - selection. Requires MongoDB >= 4.4. - - See the MongoDB documentation for a full list of supported options by - server version. - - :param keys: a single key or a list containing (key, direction) pairs - or keys specifying the index to create. - :param kwargs: any additional index creation - options (see the above list) should be passed as keyword - arguments. - - .. versionchanged:: 3.11 - Added the ``hidden`` option. - .. versionchanged:: 3.2 - Added the ``partialFilterExpression`` option to support partial - indexes. - - .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/ - """ - keys = _index_list(keys) - if kwargs.get("name") is None: - kwargs["name"] = _gen_index_name(keys) - kwargs["key"] = _index_document(keys) - collation = validate_collation_or_none(kwargs.pop("collation", None)) - self.__document = kwargs - if collation is not None: - self.__document["collation"] = collation - - @property - def document(self) -> dict[str, Any]: - """An index document suitable for passing to the createIndexes - command. - """ - return self.__document - - -class SearchIndexModel: - """Represents a search index to create.""" - - __slots__ = ("__document",) - - def __init__( - self, - definition: Mapping[str, Any], - name: Optional[str] = None, - type: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Create a Search Index instance. 
- - For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`. - - :param definition: The definition for this index. - :param name: The name for this index, if present. - :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". - :param kwargs: Keyword arguments supplying any additional options. - - .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. - .. versionadded:: 4.5 - .. versionchanged:: 4.7 - Added the type and kwargs arguments. - """ - self.__document: dict[str, Any] = {} - if name is not None: - self.__document["name"] = name - self.__document["definition"] = definition - if type is not None: - self.__document["type"] = type - self.__document.update(kwargs) - - @property - def document(self) -> Mapping[str, Any]: - """The document for this index.""" - return self.__document diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 27597a6db..57265a7d5 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -16,17 +16,14 @@ from __future__ import annotations import collections import contextlib -import copy import logging import os -import platform import socket import ssl import sys import threading import time import weakref -from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -39,39 +36,18 @@ from typing import ( Union, ) -import bson from bson import DEFAULT_CODEC_OPTIONS -from pymongo import __version__, _csot -from pymongo.asynchronous import helpers +from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.common import ( +from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.asynchronous.network import command, receive_message +from pymongo.common import ( MAX_BSON_SIZE, - MAX_CONNECTING, - MAX_IDLE_TIME_SEC, 
MAX_MESSAGE_SIZE, - MAX_POOL_SIZE, MAX_WIRE_VERSION, MAX_WRITE_BATCH_SIZE, - MIN_POOL_SIZE, ORDERED_TYPES, - WAIT_QUEUE_TIMEOUT, ) -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.logger import ( - _CONNECTION_LOGGER, - _ConnectionStatusMessage, - _debug_log, - _verbose_connection_error_reason, -) -from pymongo.asynchronous.monitoring import ( - ConnectionCheckOutFailedReason, - ConnectionClosedReason, - _EventListeners, -) -from pymongo.asynchronous.network import command, receive_message -from pymongo.asynchronous.read_preferences import ReadPreference from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, @@ -86,8 +62,21 @@ from pymongo.errors import ( # type:ignore[attr-defined] WaitQueueTimeoutError, _CertificateError, ) +from pymongo.hello import Hello, HelloCompat from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.logger import ( + _CONNECTION_LOGGER, + _ConnectionStatusMessage, + _debug_log, + _verbose_connection_error_reason, +) +from pymongo.monitoring import ( + ConnectionCheckOutFailedReason, + ConnectionClosedReason, +) from pymongo.network_layer import async_sendall +from pymongo.pool_options import PoolOptions +from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker @@ -96,22 +85,19 @@ from pymongo.ssl_support import HAS_SNI, SSLError if TYPE_CHECKING: from bson import CodecOptions from bson.objectid import ObjectId - from pymongo.asynchronous.auth import MongoCredential, _AuthContext - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.compression_support import ( - CompressionSettings, + from pymongo.asynchronous.auth import _AuthContext + from pymongo.asynchronous.client_session 
import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler + from pymongo.compression_support import ( SnappyContext, ZlibContext, ZstdContext, ) - from pymongo.asynchronous.message import _OpMsg, _OpReply - from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.typings import ClusterTime, _Address, _CollationIn - from pymongo.driver_info import DriverInfo - from pymongo.pyopenssl_context import SSLContext, _sslConn + from pymongo.message import _OpMsg, _OpReply + from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern - from pymongo.server_api import ServerApi + from pymongo.read_preferences import _ServerMode + from pymongo.typings import ClusterTime, _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -191,236 +177,6 @@ else: _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) -_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} - -if sys.platform.startswith("linux"): - _METADATA["os"] = { - "type": platform.system(), - "name": platform.system(), - "architecture": platform.machine(), - # Kernel version (e.g. 4.4.0-17-generic). - "version": platform.release(), - } -elif sys.platform == "darwin": - _METADATA["os"] = { - "type": platform.system(), - "name": platform.system(), - "architecture": platform.machine(), - # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin - # kernel version. - "version": platform.mac_ver()[0], - } -elif sys.platform == "win32": - _ver = sys.getwindowsversion() - _METADATA["os"] = { - "type": "Windows", - "name": "Windows", - # Avoid using platform calls, see PYTHON-4455. - "architecture": os.environ.get("PROCESSOR_ARCHITECTURE") or platform.machine(), - # Windows patch level (e.g. 10.0.17763-SP0). 
- "version": ".".join(map(str, _ver[:3])) + f"-SP{_ver[-1] or '0'}", - } -elif sys.platform.startswith("java"): - _name, _ver, _arch = platform.java_ver()[-1] - _METADATA["os"] = { - # Linux, Windows 7, Mac OS X, etc. - "type": _name, - "name": _name, - # x86, x86_64, AMD64, etc. - "architecture": _arch, - # Linux kernel version, OSX version, etc. - "version": _ver, - } -else: - # Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11) - _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) - _METADATA["os"] = { - "type": platform.system(), - "name": " ".join([part for part in _aliased[:2] if part]), - "architecture": platform.machine(), - "version": _aliased[2], - } - -if platform.python_implementation().startswith("PyPy"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.pypy_version_info)), # type: ignore - "(Python %s)" % ".".join(map(str, sys.version_info)), - ) - ) -elif sys.platform.startswith("java"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.version_info)), - "(%s)" % " ".join((platform.system(), platform.release())), - ) - ) -else: - _METADATA["platform"] = " ".join( - (platform.python_implementation(), ".".join(map(str, sys.version_info))) - ) - -DOCKER_ENV_PATH = "/.dockerenv" -ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" - -RUNTIME_NAME_DOCKER = "docker" -ORCHESTRATOR_NAME_K8S = "kubernetes" - - -def get_container_env_info() -> dict[str, str]: - """Returns the runtime and orchestrator of a container. 
- If neither value is present, the metadata client.env.container field will be omitted.""" - container = {} - - if Path(DOCKER_ENV_PATH).exists(): - container["runtime"] = RUNTIME_NAME_DOCKER - if os.getenv(ENV_VAR_K8S): - container["orchestrator"] = ORCHESTRATOR_NAME_K8S - - return container - - -def _is_lambda() -> bool: - if os.getenv("AWS_LAMBDA_RUNTIME_API"): - return True - env = os.getenv("AWS_EXECUTION_ENV") - if env: - return env.startswith("AWS_Lambda_") - return False - - -def _is_azure_func() -> bool: - return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) - - -def _is_gcp_func() -> bool: - return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) - - -def _is_vercel() -> bool: - return bool(os.getenv("VERCEL")) - - -def _is_faas() -> bool: - return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() - - -def _getenv_int(key: str) -> Optional[int]: - """Like os.getenv but returns an int, or None if the value is missing/malformed.""" - val = os.getenv(key) - if not val: - return None - try: - return int(val) - except ValueError: - return None - - -def _metadata_env() -> dict[str, Any]: - env: dict[str, Any] = {} - container = get_container_env_info() - if container: - env["container"] = container - # Skip if multiple (or no) envs are matched. 
- if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: - return env - if _is_lambda(): - env["name"] = "aws.lambda" - region = os.getenv("AWS_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") - if memory_mb is not None: - env["memory_mb"] = memory_mb - elif _is_azure_func(): - env["name"] = "azure.func" - elif _is_gcp_func(): - env["name"] = "gcp.func" - region = os.getenv("FUNCTION_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("FUNCTION_MEMORY_MB") - if memory_mb is not None: - env["memory_mb"] = memory_mb - timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") - if timeout_sec is not None: - env["timeout_sec"] = timeout_sec - elif _is_vercel(): - env["name"] = "vercel" - region = os.getenv("VERCEL_REGION") - if region: - env["region"] = region - return env - - -_MAX_METADATA_SIZE = 512 - - -# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations -def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: - """Perform metadata truncation.""" - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 1. Omit fields from env except env.name. - env_name = metadata.get("env", {}).get("name") - if env_name: - metadata["env"] = {"name": env_name} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 2. Omit fields from os except os.type. - os_type = metadata.get("os", {}).get("type") - if os_type: - metadata["os"] = {"type": os_type} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 3. Omit the env document entirely. - metadata.pop("env", None) - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # 4. Truncate platform. 
- overflow = encoded_size - _MAX_METADATA_SIZE - plat = metadata.get("platform", "") - if plat: - plat = plat[:-overflow] - if plat: - metadata["platform"] = plat - else: - metadata.pop("platform", None) - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # 5. Truncate driver info. - overflow = encoded_size - _MAX_METADATA_SIZE - driver = metadata.get("driver", {}) - if driver: - # Truncate driver version. - driver_version = driver.get("version")[:-overflow] - if len(driver_version) >= len(_METADATA["driver"]["version"]): - metadata["driver"]["version"] = driver_version - else: - metadata["driver"]["version"] = _METADATA["driver"]["version"] - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # Truncate driver name. - overflow = encoded_size - _MAX_METADATA_SIZE - driver_name = driver.get("name")[:-overflow] - if len(driver_name) >= len(_METADATA["driver"]["name"]): - metadata["driver"]["name"] = driver_name - else: - metadata["driver"]["name"] = _METADATA["driver"]["name"] - - -# If the first getaddrinfo call of this interpreter's life is on a thread, -# while the main thread holds the import lock, getaddrinfo deadlocks trying -# to import the IDNA codec. Import it here, where presumably we're on the -# main thread, to avoid the deadlock. See PYTHON-607. -"foo".encode("idna") - - def _raise_connection_failure( address: Any, error: Exception, @@ -481,238 +237,6 @@ def format_timeout_details(details: Optional[dict[str, float]]) -> str: return result -class PoolOptions: - """Read only connection pool options for an AsyncMongoClient. - - Should not be instantiated directly by application developers. 
Access - a client's pool options via - :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: - - pool_opts = client.options.pool_options - pool_opts.max_pool_size - pool_opts.min_pool_size - - """ - - __slots__ = ( - "__max_pool_size", - "__min_pool_size", - "__max_idle_time_seconds", - "__connect_timeout", - "__socket_timeout", - "__wait_queue_timeout", - "__ssl_context", - "__tls_allow_invalid_hostnames", - "__event_listeners", - "__appname", - "__driver", - "__metadata", - "__compression_settings", - "__max_connecting", - "__pause_enabled", - "__server_api", - "__load_balanced", - "__credentials", - ) - - def __init__( - self, - max_pool_size: int = MAX_POOL_SIZE, - min_pool_size: int = MIN_POOL_SIZE, - max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, - connect_timeout: Optional[float] = None, - socket_timeout: Optional[float] = None, - wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, - ssl_context: Optional[SSLContext] = None, - tls_allow_invalid_hostnames: bool = False, - event_listeners: Optional[_EventListeners] = None, - appname: Optional[str] = None, - driver: Optional[DriverInfo] = None, - compression_settings: Optional[CompressionSettings] = None, - max_connecting: int = MAX_CONNECTING, - pause_enabled: bool = True, - server_api: Optional[ServerApi] = None, - load_balanced: Optional[bool] = None, - credentials: Optional[MongoCredential] = None, - ): - self.__max_pool_size = max_pool_size - self.__min_pool_size = min_pool_size - self.__max_idle_time_seconds = max_idle_time_seconds - self.__connect_timeout = connect_timeout - self.__socket_timeout = socket_timeout - self.__wait_queue_timeout = wait_queue_timeout - self.__ssl_context = ssl_context - self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames - self.__event_listeners = event_listeners - self.__appname = appname - self.__driver = driver - self.__compression_settings = compression_settings - self.__max_connecting = max_connecting - self.__pause_enabled = 
pause_enabled - self.__server_api = server_api - self.__load_balanced = load_balanced - self.__credentials = credentials - self.__metadata = copy.deepcopy(_METADATA) - if appname: - self.__metadata["application"] = {"name": appname} - - # Combine the "driver" AsyncMongoClient option with PyMongo's info, like: - # { - # 'driver': { - # 'name': 'PyMongo|MyDriver', - # 'version': '4.2.0|1.2.3', - # }, - # 'platform': 'CPython 3.8.0|MyPlatform' - # } - if driver: - if driver.name: - self.__metadata["driver"]["name"] = "{}|{}".format( - _METADATA["driver"]["name"], - driver.name, - ) - if driver.version: - self.__metadata["driver"]["version"] = "{}|{}".format( - _METADATA["driver"]["version"], - driver.version, - ) - if driver.platform: - self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) - - env = _metadata_env() - if env: - self.__metadata["env"] = env - - _truncate_metadata(self.__metadata) - - @property - def _credentials(self) -> Optional[MongoCredential]: - """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" - return self.__credentials - - @property - def non_default_options(self) -> dict[str, Any]: - """The non-default options this pool was created with. - - Added for CMAP's :class:`PoolCreatedEvent`. 
- """ - opts = {} - if self.__max_pool_size != MAX_POOL_SIZE: - opts["maxPoolSize"] = self.__max_pool_size - if self.__min_pool_size != MIN_POOL_SIZE: - opts["minPoolSize"] = self.__min_pool_size - if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: - assert self.__max_idle_time_seconds is not None - opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 - if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: - assert self.__wait_queue_timeout is not None - opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 - if self.__max_connecting != MAX_CONNECTING: - opts["maxConnecting"] = self.__max_connecting - return opts - - @property - def max_pool_size(self) -> float: - """The maximum allowable number of concurrent connections to each - connected server. Requests to a server will block if there are - `maxPoolSize` outstanding connections to the requested server. - Defaults to 100. Cannot be 0. - - When a server's pool has reached `max_pool_size`, operations for that - server block waiting for a socket to be returned to the pool. If - ``waitQueueTimeoutMS`` is set, a blocked operation will raise - :exc:`~pymongo.errors.ConnectionFailure` after a timeout. - By default ``waitQueueTimeoutMS`` is not set. - """ - return self.__max_pool_size - - @property - def min_pool_size(self) -> int: - """The minimum required number of concurrent connections that the pool - will maintain to each connected server. Default is 0. - """ - return self.__min_pool_size - - @property - def max_connecting(self) -> int: - """The maximum number of concurrent connection creation attempts per - pool. Defaults to 2. - """ - return self.__max_connecting - - @property - def pause_enabled(self) -> bool: - return self.__pause_enabled - - @property - def max_idle_time_seconds(self) -> Optional[int]: - """The maximum number of seconds that a connection can remain - idle in the pool before being removed and replaced. Defaults to - `None` (no limit). 
- """ - return self.__max_idle_time_seconds - - @property - def connect_timeout(self) -> Optional[float]: - """How long a connection can take to be opened before timing out.""" - return self.__connect_timeout - - @property - def socket_timeout(self) -> Optional[float]: - """How long a send or receive on a socket can take before timing out.""" - return self.__socket_timeout - - @property - def wait_queue_timeout(self) -> Optional[int]: - """How long a thread will wait for a socket from the pool if the pool - has no free sockets. - """ - return self.__wait_queue_timeout - - @property - def _ssl_context(self) -> Optional[SSLContext]: - """An SSLContext instance or None.""" - return self.__ssl_context - - @property - def tls_allow_invalid_hostnames(self) -> bool: - """If True skip ssl.match_hostname.""" - return self.__tls_allow_invalid_hostnames - - @property - def _event_listeners(self) -> Optional[_EventListeners]: - """An instance of pymongo.monitoring._EventListeners.""" - return self.__event_listeners - - @property - def appname(self) -> Optional[str]: - """The application name, for sending with hello in server handshake.""" - return self.__appname - - @property - def driver(self) -> Optional[DriverInfo]: - """Driver name and version, for sending with hello in handshake.""" - return self.__driver - - @property - def _compression_settings(self) -> Optional[CompressionSettings]: - return self.__compression_settings - - @property - def metadata(self) -> dict[str, Any]: - """A dict of metadata about the application, driver, os, and platform.""" - return self.__metadata.copy() - - @property - def server_api(self) -> Optional[ServerApi]: - """A pymongo.server_api.ServerApi or None.""" - return self.__server_api - - @property - def load_balanced(self) -> Optional[bool]: - """True if this Pool is configured in load balanced mode.""" - return self.__load_balanced - - class _CancellationContext: def __init__(self) -> None: self._cancelled = False @@ -727,7 +251,7 @@ class 
_CancellationContext: return self._cancelled -class Connection: +class AsyncConnection: """Store a connection with some metadata. :param conn: a raw connection object @@ -946,7 +470,7 @@ class Connection: self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() response_doc = unpacked_docs[0] - helpers._check_command_response(response_doc, self.max_wire_version) + helpers_shared._check_command_response(response_doc, self.max_wire_version) return response_doc @_handle_reauth @@ -962,7 +486,7 @@ class Connection: write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, collation: Optional[_CollationIn] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, client: Optional[AsyncMongoClient] = None, retryable_write: bool = False, publish_events: bool = True, @@ -982,7 +506,7 @@ class Connection: :param parse_write_concern_error: Whether to parse the ``writeConcernError`` field in the command response. :param collation: The collation for this command. - :param session: optional ClientSession instance. + :param session: optional AsyncClientSession instance. :param client: optional AsyncMongoClient for gossipping $clusterTime. :param retryable_write: True if this command is a retryable write. :param publish_events: Should we publish events for this command? @@ -1099,7 +623,7 @@ class Connection: result = reply.command_response(codec_options) # Raises NotPrimaryError or OperationFailure. 
- helpers._check_command_response(result, self.max_wire_version) + helpers_shared._check_command_response(result, self.max_wire_version) return result async def authenticate(self, reauthenticate: bool = False) -> None: @@ -1137,7 +661,7 @@ class Connection: ) def validate_session( - self, client: Optional[AsyncMongoClient], session: Optional[ClientSession] + self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession] ) -> None: """Validate this session before use with client. @@ -1190,7 +714,7 @@ class Connection: def send_cluster_time( self, command: MutableMapping[str, Any], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], client: Optional[AsyncMongoClient], ) -> None: """Add $clusterTime.""" @@ -1250,7 +774,7 @@ class Connection: return hash(self.conn) def __repr__(self) -> str: - return "Connection({}){} at {}".format( + return "AsyncConnection({}){} at {}".format( repr(self.conn), self.closed and " CLOSED" or "", id(self), @@ -1441,7 +965,7 @@ class Pool: """ :param address: a (hostname, port) tuple :param options: a PoolOptions instance - :param handshake: whether to call hello for each new Connection + :param handshake: whether to call hello for each new AsyncConnection """ if options.pause_enabled: self.state = PoolState.PAUSED @@ -1512,7 +1036,7 @@ class Pool: # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: set[Connection] = set() + self.__pinned_sockets: set[AsyncConnection] = set() self.ncursors = 0 self.ntxns = 0 @@ -1700,8 +1224,8 @@ class Pool: self.requests -= 1 self.size_cond.notify() - async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: - """Connect to Mongo and return a new Connection. 
+ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection: + """Connect to Mongo and return a new AsyncConnection. Can raise ConnectionFailure. @@ -1751,7 +1275,7 @@ class Pool: raise - conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type] + conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type] async with self.lock: self.active_contexts.add(conn.cancel_context) try: @@ -1771,10 +1295,10 @@ class Pool: @contextlib.asynccontextmanager async def checkout( self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AsyncGenerator[Connection, None]: + ) -> AsyncGenerator[AsyncConnection, None]: """Get a connection from the pool. Use with a "with" statement. - Returns a :class:`Connection` object wrapping a connected + Returns a :class:`AsyncConnection` object wrapping a connected :class:`socket.socket`. This method should always be used in a with-statement:: @@ -1874,8 +1398,8 @@ class Pool: async def _get_conn( self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None - ) -> Connection: - """Get or create a Connection. Can raise ConnectionFailure.""" + ) -> AsyncConnection: + """Get or create a AsyncConnection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise @@ -1998,7 +1522,7 @@ class Pool: conn.active = True return conn - async def checkin(self, conn: Connection) -> None: + async def checkin(self, conn: AsyncConnection) -> None: """Return the connection to the pool, or if it's closed discard it. :param conn: The connection to check into the pool. @@ -2070,7 +1594,7 @@ class Pool: self.operation_count -= 1 self.size_cond.notify() - def _perished(self, conn: Connection) -> bool: + def _perished(self, conn: AsyncConnection) -> bool: """Return True and close the connection if it is "perished". 
This side-effecty function checks if this socket has been idle for diff --git a/pymongo/asynchronous/read_preferences.py b/pymongo/asynchronous/read_preferences.py deleted file mode 100644 index 8b6fb6075..000000000 --- a/pymongo/asynchronous/read_preferences.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2012-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License", -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for choosing which member of a replica set to read from.""" - -from __future__ import annotations - -from collections import abc -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from pymongo.asynchronous import max_staleness_selectors -from pymongo.asynchronous.server_selectors import ( - member_with_tags_server_selector, - secondary_with_tags_server_selector, -) -from pymongo.errors import ConfigurationError - -if TYPE_CHECKING: - from pymongo.asynchronous.server_selectors import Selection - from pymongo.asynchronous.topology_description import TopologyDescription - -_IS_SYNC = False - -_PRIMARY = 0 -_PRIMARY_PREFERRED = 1 -_SECONDARY = 2 -_SECONDARY_PREFERRED = 3 -_NEAREST = 4 - - -_MONGOS_MODES = ( - "primary", - "primaryPreferred", - "secondary", - "secondaryPreferred", - "nearest", -) - -_Hedge = Mapping[str, Any] -_TagSets = Sequence[Mapping[str, Any]] - - -def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: - """Validate tag sets for a MongoClient.""" - if tag_sets is None: - return tag_sets - - if not isinstance(tag_sets, 
(list, tuple)): - raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") - if len(tag_sets) == 0: - raise ValueError( - f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags" - ) - - for tags in tag_sets: - if not isinstance(tags, abc.Mapping): - raise TypeError( - f"Tag set {tags!r} invalid, must be an instance of dict, " - "bson.son.SON or other type that inherits from " - "collection.Mapping" - ) - - return list(tag_sets) - - -def _invalid_max_staleness_msg(max_staleness: Any) -> str: - return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness - - -# Some duplication with common.py to avoid import cycle. -def _validate_max_staleness(max_staleness: Any) -> int: - """Validate max_staleness.""" - if max_staleness == -1: - return -1 - - if not isinstance(max_staleness, int): - raise TypeError(_invalid_max_staleness_msg(max_staleness)) - - if max_staleness <= 0: - raise ValueError(_invalid_max_staleness_msg(max_staleness)) - - return max_staleness - - -def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: - """Validate hedge.""" - if hedge is None: - return None - - if not isinstance(hedge, dict): - raise TypeError(f"hedge must be a dictionary, not {hedge!r}") - - return hedge - - -class _ServerMode: - """Base class for all read preferences.""" - - __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") - - def __init__( - self, - mode: int, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - self.__mongos_mode = _MONGOS_MODES[mode] - self.__mode = mode - self.__tag_sets = _validate_tag_sets(tag_sets) - self.__max_staleness = _validate_max_staleness(max_staleness) - self.__hedge = _validate_hedge(hedge) - - @property - def name(self) -> str: - """The name of this read preference.""" - return self.__class__.__name__ - - @property - def mongos_mode(self) -> str: - """The mongos mode of this read 
preference.""" - return self.__mongos_mode - - @property - def document(self) -> dict[str, Any]: - """Read preference as a document.""" - doc: dict[str, Any] = {"mode": self.__mongos_mode} - if self.__tag_sets not in (None, [{}]): - doc["tags"] = self.__tag_sets - if self.__max_staleness != -1: - doc["maxStalenessSeconds"] = self.__max_staleness - if self.__hedge not in (None, {}): - doc["hedge"] = self.__hedge - return doc - - @property - def mode(self) -> int: - """The mode of this read preference instance.""" - return self.__mode - - @property - def tag_sets(self) -> _TagSets: - """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to - read only from members whose ``dc`` tag has the value ``"ny"``. - To specify a priority-order for tag sets, provide a list of - tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag - set, ``{}``, means "read from any member that matches the mode, - ignoring tags." MongoClient tries each set of tags in turn - until it finds a set of tags with at least one matching member. - For example, to only send a query to an analytic node:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - Or using :class:`SecondaryPreferred`:: - - SecondaryPreferred(tag_sets=[{"node":"analytics"}]) - - .. seealso:: `Data-Center Awareness - `_ - """ - return list(self.__tag_sets) if self.__tag_sets else [{}] - - @property - def max_staleness(self) -> int: - """The maximum estimated length of time (in seconds) a replica set - secondary can fall behind the primary in replication before it will - no longer be selected for operations, or -1 for no maximum. - """ - return self.__max_staleness - - @property - def hedge(self) -> Optional[_Hedge]: - """The read preference ``hedge`` parameter. - - A dictionary that configures how the server will perform hedged reads. - It consists of the following keys: - - - ``enabled``: Enables or disables hedged reads in sharded clusters. 
- - Hedged reads are automatically enabled in MongoDB 4.4+ when using a - ``nearest`` read preference. To explicitly enable hedged reads, set - the ``enabled`` key to ``true``:: - - >>> Nearest(hedge={'enabled': True}) - - To explicitly disable hedged reads, set the ``enabled`` key to - ``False``:: - - >>> Nearest(hedge={'enabled': False}) - - .. versionadded:: 3.11 - """ - return self.__hedge - - @property - def min_wire_version(self) -> int: - """The wire protocol version the server must support. - - Some read preferences impose version requirements on all servers (e.g. - maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5). - - All servers' maxWireVersion must be at least this read preference's - `min_wire_version`, or the driver raises - :exc:`~pymongo.errors.ConfigurationError`. - """ - return 0 if self.__max_staleness == -1 else 5 - - def __repr__(self) -> str: - return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format( - self.name, - self.__tag_sets, - self.__max_staleness, - self.__hedge, - ) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return ( - self.mode == other.mode - and self.tag_sets == other.tag_sets - and self.max_staleness == other.max_staleness - and self.hedge == other.hedge - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __getstate__(self) -> dict[str, Any]: - """Return value of object for pickling. - - Needed explicitly because __slots__() defined. 
- """ - return { - "mode": self.__mode, - "tag_sets": self.__tag_sets, - "max_staleness": self.__max_staleness, - "hedge": self.__hedge, - } - - def __setstate__(self, value: Mapping[str, Any]) -> None: - """Restore from pickling.""" - self.__mode = value["mode"] - self.__mongos_mode = _MONGOS_MODES[self.__mode] - self.__tag_sets = _validate_tag_sets(value["tag_sets"]) - self.__max_staleness = _validate_max_staleness(value["max_staleness"]) - self.__hedge = _validate_hedge(value["hedge"]) - - def __call__(self, selection: Selection) -> Selection: - return selection - - -class Primary(_ServerMode): - """Primary read preference. - - * When directly connected to one mongod queries are allowed if the server - is standalone or a replica set primary. - * When connected to a mongos queries are sent to the primary of a shard. - * When connected to a replica set queries are sent to the primary of - the replica set. - """ - - __slots__ = () - - def __init__(self) -> None: - super().__init__(_PRIMARY) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return selection.primary_selection - - def __repr__(self) -> str: - return "Primary()" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return other.mode == _PRIMARY - return NotImplemented - - -class PrimaryPreferred(_ServerMode): - """PrimaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are sent to the primary of a shard if - available, otherwise a shard secondary. - * When connected to a replica set queries are sent to the primary if - available, otherwise a secondary. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to an available secondary until the - primary of the replica set is discovered. 
- - :param tag_sets: The :attr:`~tag_sets` to use if the primary is not - available. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` to use if the primary is not available. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - if selection.primary: - return selection.primary_selection - else: - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class Secondary(_ServerMode): - """Secondary read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries. An error is raised if no secondaries are available. - * When connected to a replica set queries are distributed among - secondaries. An error is raised if no secondaries are available. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. 
versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class SecondaryPreferred(_ServerMode): - """SecondaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries, or the shard primary if no secondary is available. - * When connected to a replica set queries are distributed among - secondaries, or the primary if no secondary is available. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to the primary of the replica set until - an available secondary is discovered. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. 
- """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - secondaries = secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - if secondaries: - return secondaries - else: - return selection.primary_selection - - -class Nearest(_ServerMode): - """Nearest read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among all members of - a shard. - * When connected to a replica set queries are distributed among all - members. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_NEAREST, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return member_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class _AggWritePref: - """Agg $out/$merge write preference. 
- - * If there are readable servers and there is any pre-5.0 server, use - primary read preference. - * Otherwise use `pref` read preference. - - :param pref: The read preference to use on MongoDB 5.0+. - """ - - __slots__ = ("pref", "effective_pref") - - def __init__(self, pref: _ServerMode): - self.pref = pref - self.effective_pref: _ServerMode = ReadPreference.PRIMARY - - def selection_hook(self, topology_description: TopologyDescription) -> None: - common_wv = topology_description.common_wire_version - if ( - topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) - and common_wv - and common_wv < 13 - ): - self.effective_pref = ReadPreference.PRIMARY - else: - self.effective_pref = self.pref - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return self.effective_pref(selection) - - def __repr__(self) -> str: - return f"_AggWritePref(pref={self.pref!r})" - - # Proxy other calls to the effective_pref so that _AggWritePref can be - # used in place of an actual read preference. - def __getattr__(self, name: str) -> Any: - return getattr(self.effective_pref, name) - - -_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) - - -def make_read_preference( - mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 -) -> _ServerMode: - if mode == _PRIMARY: - if tag_sets not in (None, [{}]): - raise ConfigurationError("Read preference primary cannot be combined with tags") - if max_staleness != -1: - raise ConfigurationError( - "Read preference primary cannot be combined with maxStalenessSeconds" - ) - return Primary() - return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore - - -_MODES = ( - "PRIMARY", - "PRIMARY_PREFERRED", - "SECONDARY", - "SECONDARY_PREFERRED", - "NEAREST", -) - - -class ReadPreference: - """An enum that defines some commonly used read preference modes. 
- - Apps can also create a custom read preference, for example:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - See :doc:`/examples/high_availability` for code examples. - - A read preference is used in three cases: - - :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: - - - ``PRIMARY``: Queries are allowed if the server is standalone or a replica - set primary. - - All other modes allow queries to standalone servers, to a replica set - primary, or to replica set secondaries. - - :class:`~pymongo.mongo_client.MongoClient` initialized with the - ``replicaSet`` option: - - - ``PRIMARY``: Read from the primary. This is the default, and provides the - strongest consistency. If no primary is available, raise - :class:`~pymongo.errors.AutoReconnect`. - - - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is - none, read from a secondary. - - - ``SECONDARY``: Read from a secondary. If no secondary is available, - raise :class:`~pymongo.errors.AutoReconnect`. - - - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise - from the primary. - - - ``NEAREST``: Read from any member. - - :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a - sharded cluster of replica sets: - - - ``PRIMARY``: Read from the primary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - This is the default. - - - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is - none, read from a secondary of the shard. - - - ``SECONDARY``: Read from a secondary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - - - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, - otherwise from the shard primary. - - - ``NEAREST``: Read from any shard member. 
- """ - - PRIMARY = Primary() - PRIMARY_PREFERRED = PrimaryPreferred() - SECONDARY = Secondary() - SECONDARY_PREFERRED = SecondaryPreferred() - NEAREST = Nearest() - - -def read_pref_mode_from_name(name: str) -> int: - """Get the read preference mode from mongos/uri name.""" - return _MONGOS_MODES.index(name) - - -class MovingAverage: - """Tracks an exponentially-weighted moving average.""" - - average: Optional[float] - - def __init__(self) -> None: - self.average = None - - def add_sample(self, sample: float) -> None: - if sample < 0: - # Likely system time change while waiting for hello response - # and not using time.monotonic. Ignore it, the next one will - # probably be valid. - return - if self.average is None: - self.average = sample - else: - # The Server Selection Spec requires an exponentially weighted - # average with alpha = 0.2. - self.average = 0.8 * self.average + 0.2 * sample - - def get(self) -> Optional[float]: - """Get the calculated average, or None if no samples yet.""" - return self.average - - def reset(self) -> None: - self.average = None diff --git a/pymongo/asynchronous/response.py b/pymongo/asynchronous/response.py deleted file mode 100644 index f19328f6e..000000000 --- a/pymongo/asynchronous/response.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Represent a response from the server.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.asynchronous.message import _OpMsg, _OpReply - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.typings import _Address, _DocumentOut - -_IS_SYNC = False - - -class Response: - __slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs") - - def __init__( - self, - data: Union[_OpMsg, _OpReply], - address: _Address, - request_id: int, - duration: Optional[timedelta], - from_command: bool, - docs: Sequence[Mapping[str, Any]], - ): - """Represent a response from the server. - - :param data: A network response message. - :param address: (host, port) of the source server. - :param request_id: The request id of this operation. - :param duration: The duration of the operation. - :param from_command: if the response is the result of a db command. 
- """ - self._data = data - self._address = address - self._request_id = request_id - self._duration = duration - self._from_command = from_command - self._docs = docs - - @property - def data(self) -> Union[_OpMsg, _OpReply]: - """Server response's raw BSON bytes.""" - return self._data - - @property - def address(self) -> _Address: - """(host, port) of the source server.""" - return self._address - - @property - def request_id(self) -> int: - """The request id of this operation.""" - return self._request_id - - @property - def duration(self) -> Optional[timedelta]: - """The duration of the operation.""" - return self._duration - - @property - def from_command(self) -> bool: - """If the response is a result from a db command.""" - return self._from_command - - @property - def docs(self) -> Sequence[Mapping[str, Any]]: - """The decoded document(s).""" - return self._docs - - -class PinnedResponse(Response): - __slots__ = ("_conn", "_more_to_come") - - def __init__( - self, - data: Union[_OpMsg, _OpReply], - address: _Address, - conn: Connection, - request_id: int, - duration: Optional[timedelta], - from_command: bool, - docs: list[_DocumentOut], - more_to_come: bool, - ): - """Represent a response to an exhaust cursor's initial query. - - :param data: A network response message. - :param address: (host, port) of the source server. - :param conn: The Connection used for the initial query. - :param request_id: The request id of this operation. - :param duration: The duration of the operation. - :param from_command: If the response is the result of a db command. - :param docs: List of documents. - :param more_to_come: Bool indicating whether cursor is ready to be - exhausted. - """ - super().__init__(data, address, request_id, duration, from_command, docs) - self._conn = conn - self._more_to_come = more_to_come - - @property - def conn(self) -> Connection: - """The Connection used for the initial query. 
- - The server will send batches on this socket, without waiting for - getMores from the client, until the result set is exhausted or there - is an error. - """ - return self._conn - - @property - def more_to_come(self) -> bool: - """If true, server is ready to send batches on the socket until the - result set is exhausted or there is an error. - """ - return self._more_to_come diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index cf812d05c..892594c97 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -27,11 +27,12 @@ from typing import ( ) from bson import _decode_all_selective -from pymongo.asynchronous.helpers import _check_command_response, _handle_reauth -from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query -from pymongo.asynchronous.response import PinnedResponse, Response +from pymongo.asynchronous.helpers import _handle_reauth from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.helpers_shared import _check_command_response +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: from queue import Queue @@ -40,11 +41,11 @@ if TYPE_CHECKING: from bson.objectid import ObjectId from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler from pymongo.asynchronous.monitor import Monitor - from pymongo.asynchronous.monitoring import _EventListeners - from pymongo.asynchronous.pool import Connection, Pool - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.server_description import ServerDescription - from pymongo.asynchronous.typings import _DocumentOut + from pymongo.asynchronous.pool import AsyncConnection, Pool + from 
pymongo.monitoring import _EventListeners + from pymongo.read_preferences import _ServerMode + from pymongo.server_description import ServerDescription + from pymongo.typings import _DocumentOut _IS_SYNC = False @@ -105,10 +106,23 @@ class Server: """Check the server's state soon.""" self._monitor.request_check() + async def operation_to_command( + self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + cmd, db = operation.as_command(conn, apply_timeout) + # Support auto encryption + if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption: + cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment] + operation.db, cmd, operation.codec_options + ) + operation.update_command(cmd) + + return cmd, db + @_handle_reauth async def run_operation( self, - conn: Connection, + conn: AsyncConnection, operation: Union[_Query, _GetMore], read_preference: _ServerMode, listeners: Optional[_EventListeners], @@ -121,26 +135,26 @@ class Server: cursors. Can raise ConnectionFailure, OperationFailure, etc. - :param conn: A Connection instance. + :param conn: An AsyncConnection instance. :param operation: A _Query or _GetMore object. :param read_preference: The read preference to use. :param listeners: Instance of _EventListeners or None. :param unpack_res: A callable that decodes the wire protocol response. + :param client: An AsyncMongoClient instance. 
""" - duration = None assert listeners is not None publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come + cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 else: - message = await operation.get_message(read_preference, conn, use_cmd) + message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - cmd, dbn = await operation.as_command(conn) if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, @@ -159,7 +173,6 @@ class Server: ) if publish: - cmd, dbn = await operation.as_command(conn) if "$db" not in cmd: cmd["$db"] = dbn assert listeners is not None @@ -195,7 +208,7 @@ class Server: ) if use_cmd: first = docs[0] - await operation.client._process_response(first, operation.session) + await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] _check_command_response(first, conn.max_wire_version) except Exception as exc: duration = datetime.now() - start @@ -278,7 +291,7 @@ class Server: ) # Decrypt response. 
- client = operation.client + client = operation.client # type: ignore[assignment] if client and client._encrypter: if use_cmd: decrypted = client._encrypter.decrypt(reply.raw_command_response()) @@ -286,7 +299,7 @@ class Server: response: Response - if client._should_pin_cursor(operation.session) or operation.exhaust: + if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the @@ -321,7 +334,7 @@ class Server: async def checkout( self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AsyncContextManager[Connection]: + ) -> AsyncContextManager[AsyncConnection]: return self.pool.checkout(handler) @property diff --git a/pymongo/asynchronous/server_description.py b/pymongo/asynchronous/server_description.py deleted file mode 100644 index 8e15c3400..000000000 --- a/pymongo/asynchronous/server_description.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Represent one server the driver is connected to.""" -from __future__ import annotations - -import time -import warnings -from typing import Any, Mapping, Optional - -from bson import EPOCH_NAIVE -from bson.objectid import ObjectId -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.typings import ClusterTime, _Address -from pymongo.server_type import SERVER_TYPE - -_IS_SYNC = False - - -class ServerDescription: - """Immutable representation of one server. - - :param address: A (host, port) pair - :param hello: Optional Hello instance - :param round_trip_time: Optional float - :param error: Optional, the last error attempting to connect to the server - :param round_trip_time: Optional float, the min latency from the most recent samples - """ - - __slots__ = ( - "_address", - "_server_type", - "_all_hosts", - "_tags", - "_replica_set_name", - "_primary", - "_max_bson_size", - "_max_message_size", - "_max_write_batch_size", - "_min_wire_version", - "_max_wire_version", - "_round_trip_time", - "_min_round_trip_time", - "_me", - "_is_writable", - "_is_readable", - "_ls_timeout_minutes", - "_error", - "_set_version", - "_election_id", - "_cluster_time", - "_last_write_date", - "_last_update_time", - "_topology_version", - ) - - def __init__( - self, - address: _Address, - hello: Optional[Hello] = None, - round_trip_time: Optional[float] = None, - error: Optional[Exception] = None, - min_round_trip_time: float = 0.0, - ) -> None: - self._address = address - if not hello: - hello = Hello({}) - - self._server_type = hello.server_type - self._all_hosts = hello.all_hosts - self._tags = hello.tags - self._replica_set_name = hello.replica_set_name - self._primary = hello.primary - self._max_bson_size = hello.max_bson_size - self._max_message_size = hello.max_message_size - self._max_write_batch_size = hello.max_write_batch_size - self._min_wire_version = hello.min_wire_version - self._max_wire_version = hello.max_wire_version - self._set_version = 
hello.set_version - self._election_id = hello.election_id - self._cluster_time = hello.cluster_time - self._is_writable = hello.is_writable - self._is_readable = hello.is_readable - self._ls_timeout_minutes = hello.logical_session_timeout_minutes - self._round_trip_time = round_trip_time - self._min_round_trip_time = min_round_trip_time - self._me = hello.me - self._last_update_time = time.monotonic() - self._error = error - self._topology_version = hello.topology_version - if error: - details = getattr(error, "details", None) - if isinstance(details, dict): - self._topology_version = details.get("topologyVersion") - - self._last_write_date: Optional[float] - if hello.last_write_date: - # Convert from datetime to seconds. - delta = hello.last_write_date - EPOCH_NAIVE - self._last_write_date = delta.total_seconds() - else: - self._last_write_date = None - - @property - def address(self) -> _Address: - """The address (host, port) of this server.""" - return self._address - - @property - def server_type(self) -> int: - """The type of this server.""" - return self._server_type - - @property - def server_type_name(self) -> str: - """The server type as a human readable string. - - .. 
versionadded:: 3.4 - """ - return SERVER_TYPE._fields[self._server_type] - - @property - def all_hosts(self) -> set[tuple[str, int]]: - """List of hosts, passives, and arbiters known to this server.""" - return self._all_hosts - - @property - def tags(self) -> Mapping[str, Any]: - return self._tags - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self._replica_set_name - - @property - def primary(self) -> Optional[tuple[str, int]]: - """This server's opinion about who the primary is, or None.""" - return self._primary - - @property - def max_bson_size(self) -> int: - return self._max_bson_size - - @property - def max_message_size(self) -> int: - return self._max_message_size - - @property - def max_write_batch_size(self) -> int: - return self._max_write_batch_size - - @property - def min_wire_version(self) -> int: - return self._min_wire_version - - @property - def max_wire_version(self) -> int: - return self._max_wire_version - - @property - def set_version(self) -> Optional[int]: - return self._set_version - - @property - def election_id(self) -> Optional[ObjectId]: - return self._election_id - - @property - def cluster_time(self) -> Optional[ClusterTime]: - return self._cluster_time - - @property - def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: - warnings.warn( - "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", - DeprecationWarning, - stacklevel=2, - ) - return self._set_version, self._election_id - - @property - def me(self) -> Optional[tuple[str, int]]: - return self._me - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - return self._ls_timeout_minutes - - @property - def last_write_date(self) -> Optional[float]: - return self._last_write_date - - @property - def last_update_time(self) -> float: - return self._last_update_time - - @property - def round_trip_time(self) -> Optional[float]: - """The current average latency or 
None.""" - # This override is for unittesting only! - if self._address in self._host_to_round_trip_time: - return self._host_to_round_trip_time[self._address] - - return self._round_trip_time - - @property - def min_round_trip_time(self) -> float: - """The min latency from the most recent samples.""" - return self._min_round_trip_time - - @property - def error(self) -> Optional[Exception]: - """The last error attempting to connect to the server, or None.""" - return self._error - - @property - def is_writable(self) -> bool: - return self._is_writable - - @property - def is_readable(self) -> bool: - return self._is_readable - - @property - def mongos(self) -> bool: - return self._server_type == SERVER_TYPE.Mongos - - @property - def is_server_type_known(self) -> bool: - return self.server_type != SERVER_TYPE.Unknown - - @property - def retryable_writes_supported(self) -> bool: - """Checks if this server supports retryable writes.""" - return ( - self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) - ) or self._server_type == SERVER_TYPE.LoadBalancer - - @property - def retryable_reads_supported(self) -> bool: - """Checks if this server supports retryable writes.""" - return self._max_wire_version >= 6 - - @property - def topology_version(self) -> Optional[Mapping[str, Any]]: - return self._topology_version - - def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription: - unknown = ServerDescription(self.address, error=error) - unknown._topology_version = self.topology_version - return unknown - - def __eq__(self, other: Any) -> bool: - if isinstance(other, ServerDescription): - return ( - (self._address == other.address) - and (self._server_type == other.server_type) - and (self._min_wire_version == other.min_wire_version) - and (self._max_wire_version == other.max_wire_version) - and (self._me == other.me) - and (self._all_hosts == other.all_hosts) - and (self._tags == other.tags) - and (self._replica_set_name == 
other.replica_set_name) - and (self._set_version == other.set_version) - and (self._election_id == other.election_id) - and (self._primary == other.primary) - and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) - and (self._error == other.error) - ) - - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - errmsg = "" - if self.error: - errmsg = f", error={self.error!r}" - return "<{} {} server_type: {}, rtt: {}{}>".format( - self.__class__.__name__, - self.address, - self.server_type_name, - self.round_trip_time, - errmsg, - ) - - # For unittesting only. Use under no circumstances! - _host_to_round_trip_time: dict = {} diff --git a/pymongo/asynchronous/server_selectors.py b/pymongo/asynchronous/server_selectors.py deleted file mode 100644 index eeaebadd6..000000000 --- a/pymongo/asynchronous/server_selectors.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2014-2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
- -"""Criteria to select some ServerDescriptions from a TopologyDescription.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast - -from pymongo.server_type import SERVER_TYPE - -if TYPE_CHECKING: - from pymongo.asynchronous.server_description import ServerDescription - from pymongo.asynchronous.topology_description import TopologyDescription - -_IS_SYNC = False - -T = TypeVar("T") -TagSet = Mapping[str, Any] -TagSets = Sequence[TagSet] - - -class Selection: - """Input or output of a server selector function.""" - - @classmethod - def from_topology_description(cls, topology_description: TopologyDescription) -> Selection: - known_servers = topology_description.known_servers - primary = None - for sd in known_servers: - if sd.server_type == SERVER_TYPE.RSPrimary: - primary = sd - break - - return Selection( - topology_description, - topology_description.known_servers, - topology_description.common_wire_version, - primary, - ) - - def __init__( - self, - topology_description: TopologyDescription, - server_descriptions: list[ServerDescription], - common_wire_version: Optional[int], - primary: Optional[ServerDescription], - ): - self.topology_description = topology_description - self.server_descriptions = server_descriptions - self.primary = primary - self.common_wire_version = common_wire_version - - def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection: - return Selection( - self.topology_description, server_descriptions, self.common_wire_version, self.primary - ) - - def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]: - secondaries = secondary_server_selector(self) - if secondaries.server_descriptions: - return max( - secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date) - ) - return None - - @property - def primary_selection(self) -> Selection: - primaries = [self.primary] if self.primary else [] - return 
self.with_server_descriptions(primaries) - - @property - def heartbeat_frequency(self) -> int: - return self.topology_description.heartbeat_frequency - - @property - def topology_type(self) -> int: - return self.topology_description.topology_type - - def __bool__(self) -> bool: - return bool(self.server_descriptions) - - def __getitem__(self, item: int) -> ServerDescription: - return self.server_descriptions[item] - - -def any_server_selector(selection: T) -> T: - return selection - - -def readable_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.is_readable] - ) - - -def writable_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.is_writable] - ) - - -def secondary_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary] - ) - - -def arbiter_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter] - ) - - -def writable_preferred_server_selector(selection: Selection) -> Selection: - """Like PrimaryPreferred but doesn't use tags or latency.""" - return writable_server_selector(selection) or secondary_server_selector(selection) - - -def apply_single_tag_set(tag_set: TagSet, selection: Selection) -> Selection: - """All servers matching one tag set. - - A tag set is a dict. A server matches if its tags are a superset: - A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}. - - The empty tag set {} matches any server. 
- """ - - def tags_match(server_tags: Mapping[str, Any]) -> bool: - for key, value in tag_set.items(): - if key not in server_tags or server_tags[key] != value: - return False - - return True - - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if tags_match(s.tags)] - ) - - -def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection: - """All servers match a list of tag sets. - - tag_sets is a list of dicts. The empty tag set {} matches any server, - and may be provided at the end of the list as a fallback. So - [{'a': 'value'}, {}] expresses a preference for servers tagged - {'a': 'value'}, but accepts any server if none matches the first - preference. - """ - for tag_set in tag_sets: - with_tag_set = apply_single_tag_set(tag_set, selection) - if with_tag_set: - return with_tag_set - - return selection.with_server_descriptions([]) - - -def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection: - """All near-enough secondaries matching the tag sets.""" - return apply_tag_sets(tag_sets, secondary_server_selector(selection)) - - -def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection: - """All near-enough members matching the tag sets.""" - return apply_tag_sets(tag_sets, readable_server_selector(selection)) diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py index f88235cf5..c41c638e6 100644 --- a/pymongo/asynchronous/settings.py +++ b/pymongo/asynchronous/settings.py @@ -20,12 +20,14 @@ import traceback from typing import Any, Collection, Optional, Type, Union from bson.objectid import ObjectId -from pymongo.asynchronous import common, monitor, pool -from pymongo.asynchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT -from pymongo.asynchronous.pool import Pool, PoolOptions -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.topology_description import 
TOPOLOGY_TYPE, _ServerSelector +from pymongo import common +from pymongo.asynchronous import monitor, pool +from pymongo.asynchronous.pool import Pool +from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT from pymongo.errors import ConfigurationError +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector _IS_SYNC = False diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py deleted file mode 100644 index 1a37bad96..000000000 --- a/pymongo/asynchronous/srv_resolver.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Support for resolving hosts and options from mongodb+srv:// URIs.""" -from __future__ import annotations - -import ipaddress -import random -from typing import TYPE_CHECKING, Any, Optional, Union - -from pymongo.asynchronous.common import CONNECT_TIMEOUT -from pymongo.errors import ConfigurationError - -if TYPE_CHECKING: - from dns import resolver - -_IS_SYNC = False - - -def _have_dnspython() -> bool: - try: - import dns # noqa: F401 - - return True - except ImportError: - return False - - -# dnspython can return bytes or str from various parts -# of its API depending on version. We always want str. 
-def maybe_decode(text: Union[str, bytes]) -> str: - if isinstance(text, bytes): - return text.decode() - return text - - -# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. -def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: - from dns import resolver - - if hasattr(resolver, "resolve"): - # dnspython >= 2 - return resolver.resolve(*args, **kwargs) - # dnspython 1.X - return resolver.query(*args, **kwargs) - - -_INVALID_HOST_MSG = ( - "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " - "Did you mean to use 'mongodb://'?" -) - - -class _SrvResolver: - def __init__( - self, - fqdn: str, - connect_timeout: Optional[float], - srv_service_name: str, - srv_max_hosts: int = 0, - ): - self.__fqdn = fqdn - self.__srv = srv_service_name - self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT - self.__srv_max_hosts = srv_max_hosts or 0 - # Validate the fully qualified domain name. - try: - ipaddress.ip_address(fqdn) - raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) - except ValueError: - pass - - try: - self.__plist = self.__fqdn.split(".")[1:] - except Exception: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None - self.__slen = len(self.__plist) - if self.__slen < 2: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) - - def get_options(self) -> Optional[str]: - from dns import resolver - - try: - results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) - except (resolver.NoAnswer, resolver.NXDOMAIN): - # No TXT records - return None - except Exception as exc: - raise ConfigurationError(str(exc)) from None - if len(results) > 1: - raise ConfigurationError("Only one TXT record is supported") - return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") - - def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: - try: - results = _resolve( - "_" + self.__srv + "._tcp." 
+ self.__fqdn, "SRV", lifetime=self.__connect_timeout - ) - except Exception as exc: - if not encapsulate_errors: - # Raise the original error. - raise - # Else, raise all errors as ConfigurationError. - raise ConfigurationError(str(exc)) from None - return results - - def _get_srv_response_and_hosts( - self, encapsulate_errors: bool - ) -> tuple[resolver.Answer, list[tuple[str, Any]]]: - results = self._resolve_uri(encapsulate_errors) - - # Construct address tuples - nodes = [ - (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results - ] - - # Validate hosts - for node in nodes: - try: - nlist = node[0].lower().split(".")[1:][-self.__slen :] - except Exception: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None - if self.__plist != nlist: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") - if self.__srv_max_hosts: - nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) - return results, nodes - - def get_hosts(self) -> list[tuple[str, Any]]: - _, nodes = self._get_srv_response_and_hosts(True) - return nodes - - def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: - results, nodes = self._get_srv_response_and_hosts(False) - rrset = results.rrset - ttl = rrset.ttl if rrset else 0 - return nodes, ttl diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index ac578113b..16cdd0eba 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -27,33 +27,12 @@ import weakref from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, helpers_constants -from pymongo.asynchronous import common, periodic_executor +from pymongo import _csot, common, helpers_shared +from pymongo.asynchronous import periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.asynchronous.hello import Hello -from 
pymongo.asynchronous.logger import ( - _SERVER_SELECTION_LOGGER, - _debug_log, - _ServerSelectionStatusMessage, -) from pymongo.asynchronous.monitor import SrvMonitor -from pymongo.asynchronous.pool import Pool, PoolOptions +from pymongo.asynchronous.pool import Pool from pymongo.asynchronous.server import Server -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.server_selectors import ( - Selection, - any_server_selector, - arbiter_server_selector, - secondary_server_selector, - writable_server_selector, -) -from pymongo.asynchronous.topology_description import ( - SRV_POLLING_TOPOLOGIES, - TOPOLOGY_TYPE, - TopologyDescription, - _updated_topology_description_srv_polling, - updated_topology_description, -) from pymongo.errors import ( ConnectionFailure, InvalidOperation, @@ -64,12 +43,34 @@ from pymongo.errors import ( ServerSelectionTimeoutError, WriteError, ) +from pymongo.hello import Hello from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.logger import ( + _SERVER_SELECTION_LOGGER, + _debug_log, + _ServerSelectionStatusMessage, +) +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import ( + Selection, + any_server_selector, + arbiter_server_selector, + secondary_server_selector, + writable_server_selector, +) +from pymongo.topology_description import ( + SRV_POLLING_TOPOLOGIES, + TOPOLOGY_TYPE, + TopologyDescription, + _updated_topology_description_srv_polling, + updated_topology_description, +) if TYPE_CHECKING: from bson import ObjectId from pymongo.asynchronous.settings import TopologySettings - from pymongo.asynchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = False @@ -781,8 +782,8 @@ class Topology: # Default error code if one does not exist. 
default = 10107 if isinstance(error, NotPrimaryError) else None err_code = error.details.get("code", default) # type: ignore[union-attr] - if err_code in helpers_constants._NOT_PRIMARY_CODES: - is_shutting_down = err_code in helpers_constants._SHUTDOWN_CODES + if err_code in helpers_shared._NOT_PRIMARY_CODES: + is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES # Mark server Unknown, clear the pool, and request check. if not self._settings.load_balanced: await self._process_change(ServerDescription(address, error=error)) diff --git a/pymongo/asynchronous/topology_description.py b/pymongo/asynchronous/topology_description.py deleted file mode 100644 index ce7aff7f5..000000000 --- a/pymongo/asynchronous/topology_description.py +++ /dev/null @@ -1,678 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
- -"""Represent a deployment of MongoDB servers.""" -from __future__ import annotations - -from random import sample -from typing import ( - Any, - Callable, - List, - Mapping, - MutableMapping, - NamedTuple, - Optional, - cast, -) - -from bson.min_key import MinKey -from bson.objectid import ObjectId -from pymongo.asynchronous import common -from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.server_selectors import Selection -from pymongo.asynchronous.typings import _Address -from pymongo.errors import ConfigurationError -from pymongo.server_type import SERVER_TYPE - -_IS_SYNC = False - - -# Enumeration for various kinds of MongoDB cluster topologies. -class _TopologyType(NamedTuple): - Single: int - ReplicaSetNoPrimary: int - ReplicaSetWithPrimary: int - Sharded: int - Unknown: int - LoadBalanced: int - - -TOPOLOGY_TYPE = _TopologyType(*range(6)) - -# Topologies compatible with SRV record polling. -SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) - - -_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] - - -class TopologyDescription: - def __init__( - self, - topology_type: int, - server_descriptions: dict[_Address, ServerDescription], - replica_set_name: Optional[str], - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], - topology_settings: Any, - ) -> None: - """Representation of a deployment of MongoDB servers. 
- - :param topology_type: initial type - :param server_descriptions: dict of (address, ServerDescription) for - all seeds - :param replica_set_name: replica set name or None - :param max_set_version: greatest setVersion seen from a primary, or None - :param max_election_id: greatest electionId seen from a primary, or None - :param topology_settings: a TopologySettings - """ - self._topology_type = topology_type - self._replica_set_name = replica_set_name - self._server_descriptions = server_descriptions - self._max_set_version = max_set_version - self._max_election_id = max_election_id - - # The heartbeat_frequency is used in staleness estimates. - self._topology_settings = topology_settings - - # Is PyMongo compatible with all servers' wire protocols? - self._incompatible_err = None - if self._topology_type != TOPOLOGY_TYPE.LoadBalanced: - self._init_incompatible_err() - - # Server Discovery And Monitoring Spec: Whenever a client updates the - # TopologyDescription from an hello response, it MUST set - # TopologyDescription.logicalSessionTimeoutMinutes to the smallest - # logicalSessionTimeoutMinutes value among ServerDescriptions of all - # data-bearing server types. If any have a null - # logicalSessionTimeoutMinutes, then - # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. - readable_servers = self.readable_servers - if not readable_servers: - self._ls_timeout_minutes = None - elif any(s.logical_session_timeout_minutes is None for s in readable_servers): - self._ls_timeout_minutes = None - else: - self._ls_timeout_minutes = min( # type: ignore[type-var] - s.logical_session_timeout_minutes for s in readable_servers - ) - - def _init_incompatible_err(self) -> None: - """Internal compatibility check for non-load balanced topologies.""" - for s in self._server_descriptions.values(): - if not s.is_server_type_known: - continue - - # s.min/max_wire_version is the server's wire protocol. 
- # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. - server_too_new = ( - # Server too new. - s.min_wire_version is not None - and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION - ) - - server_too_old = ( - # Server too old. - s.max_wire_version is not None - and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION - ) - - if server_too_new: - self._incompatible_err = ( - "Server at %s:%d requires wire version %d, but this " # type: ignore - "version of PyMongo only supports up to %d." - % ( - s.address[0], - s.address[1] or 0, - s.min_wire_version, - common.MAX_SUPPORTED_WIRE_VERSION, - ) - ) - - elif server_too_old: - self._incompatible_err = ( - "Server at %s:%d reports wire version %d, but this " # type: ignore - "version of PyMongo requires at least %d (MongoDB %s)." - % ( - s.address[0], - s.address[1] or 0, - s.max_wire_version, - common.MIN_SUPPORTED_WIRE_VERSION, - common.MIN_SUPPORTED_SERVER_VERSION, - ) - ) - - break - - def check_compatible(self) -> None: - """Raise ConfigurationError if any server is incompatible. - - A server is incompatible if its wire protocol version range does not - overlap with PyMongo's. - """ - if self._incompatible_err: - raise ConfigurationError(self._incompatible_err) - - def has_server(self, address: _Address) -> bool: - return address in self._server_descriptions - - def reset_server(self, address: _Address) -> TopologyDescription: - """A copy of this description, with one server marked Unknown.""" - unknown_sd = self._server_descriptions[address].to_unknown() - return updated_topology_description(self, unknown_sd) - - def reset(self) -> TopologyDescription: - """A copy of this description, with all servers marked Unknown.""" - if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - else: - topology_type = self._topology_type - - # The default ServerDescription's type is Unknown. 
- sds = {address: ServerDescription(address) for address in self._server_descriptions} - - return TopologyDescription( - topology_type, - sds, - self._replica_set_name, - self._max_set_version, - self._max_election_id, - self._topology_settings, - ) - - def server_descriptions(self) -> dict[_Address, ServerDescription]: - """dict of (address, - :class:`~pymongo.server_description.ServerDescription`). - """ - return self._server_descriptions.copy() - - @property - def topology_type(self) -> int: - """The type of this topology.""" - return self._topology_type - - @property - def topology_type_name(self) -> str: - """The topology type as a human readable string. - - .. versionadded:: 3.4 - """ - return TOPOLOGY_TYPE._fields[self._topology_type] - - @property - def replica_set_name(self) -> Optional[str]: - """The replica set name.""" - return self._replica_set_name - - @property - def max_set_version(self) -> Optional[int]: - """Greatest setVersion seen from a primary, or None.""" - return self._max_set_version - - @property - def max_election_id(self) -> Optional[ObjectId]: - """Greatest electionId seen from a primary, or None.""" - return self._max_election_id - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - """Minimum logical session timeout, or None.""" - return self._ls_timeout_minutes - - @property - def known_servers(self) -> list[ServerDescription]: - """List of Servers of types besides Unknown.""" - return [s for s in self._server_descriptions.values() if s.is_server_type_known] - - @property - def has_known_servers(self) -> bool: - """Whether there are any Servers of types besides Unknown.""" - return any(s for s in self._server_descriptions.values() if s.is_server_type_known) - - @property - def readable_servers(self) -> list[ServerDescription]: - """List of readable Servers.""" - return [s for s in self._server_descriptions.values() if s.is_readable] - - @property - def common_wire_version(self) -> Optional[int]: - """Minimum 
of all servers' max wire versions, or None.""" - servers = self.known_servers - if servers: - return min(s.max_wire_version for s in self.known_servers) - - return None - - @property - def heartbeat_frequency(self) -> int: - return self._topology_settings.heartbeat_frequency - - @property - def srv_max_hosts(self) -> int: - return self._topology_settings._srv_max_hosts - - def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: - if not selection: - return [] - round_trip_times: list[float] = [] - for server in selection.server_descriptions: - if server.round_trip_time is None: - config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}" - raise ConfigurationError(config_err_msg) - round_trip_times.append(server.round_trip_time) - # Round trip time in seconds. - fastest = min(round_trip_times) - threshold = self._topology_settings.local_threshold_ms / 1000.0 - return [ - s - for s in selection.server_descriptions - if (cast(float, s.round_trip_time) - fastest) <= threshold - ] - - def apply_selector( - self, - selector: Any, - address: Optional[_Address] = None, - custom_selector: Optional[_ServerSelector] = None, - ) -> list[ServerDescription]: - """List of servers matching the provided selector(s). - - :param selector: a callable that takes a Selection as input and returns - a Selection as output. For example, an instance of a read - preference from :mod:`~pymongo.read_preferences`. - :param address: A server address to select. - :param custom_selector: A callable that augments server - selection rules. Accepts a list of - :class:`~pymongo.server_description.ServerDescription` objects and - return a list of server descriptions that should be considered - suitable for the desired operation. - - .. 
versionadded:: 3.4 - """ - if getattr(selector, "min_wire_version", 0): - common_wv = self.common_wire_version - if common_wv and common_wv < selector.min_wire_version: - raise ConfigurationError( - "%s requires min wire version %d, but topology's min" - " wire version is %d" % (selector, selector.min_wire_version, common_wv) - ) - - if isinstance(selector, _AggWritePref): - selector.selection_hook(self) - - if self.topology_type == TOPOLOGY_TYPE.Unknown: - return [] - elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced): - # Ignore selectors for standalone and load balancer mode. - return self.known_servers - if address: - # Ignore selectors when explicit address is requested. - description = self.server_descriptions().get(address) - return [description] if description else [] - - selection = Selection.from_topology_description(self) - # Ignore read preference for sharded clusters. - if self.topology_type != TOPOLOGY_TYPE.Sharded: - selection = selector(selection) - - # Apply custom selector followed by localThresholdMS. - if custom_selector is not None and selection: - selection = selection.with_server_descriptions( - custom_selector(selection.server_descriptions) - ) - return self._apply_local_threshold(selection) - - def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool: - """Does this topology have any readable servers available matching the - given read preference? - - :param read_preference: an instance of a read preference from - :mod:`~pymongo.read_preferences`. Defaults to - :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. - - .. note:: When connected directly to a single server this method - always returns ``True``. - - .. versionadded:: 3.4 - """ - common.validate_read_preference("read_preference", read_preference) - return any(self.apply_selector(read_preference)) - - def has_writable_server(self) -> bool: - """Does this topology have a writable server available? - - .. 
note:: When connected directly to a single server this method - always returns ``True``. - - .. versionadded:: 3.4 - """ - return self.has_readable_server(ReadPreference.PRIMARY) - - def __repr__(self) -> str: - # Sort the servers by address. - servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) - return "<{} id: {}, topology_type: {}, servers: {!r}>".format( - self.__class__.__name__, - self._topology_settings._topology_id, - self.topology_type_name, - servers, - ) - - -# If topology type is Unknown and we receive a hello response, what should -# the new topology type be? -_SERVER_TYPE_TO_TOPOLOGY_TYPE = { - SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, - SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, - SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. -} - - -def updated_topology_description( - topology_description: TopologyDescription, server_description: ServerDescription -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. - - :param topology_description: the current TopologyDescription - :param server_description: a new ServerDescription that resulted from - a hello call - - Called after attempting (successfully or not) to call hello on the - server at server_description.address. Does not modify topology_description. - """ - address = server_description.address - - # These values will be updated, if necessary, to form the new - # TopologyDescription. - topology_type = topology_description.topology_type - set_name = topology_description.replica_set_name - max_set_version = topology_description.max_set_version - max_election_id = topology_description.max_election_id - server_type = server_description.server_type - - # Don't mutate the original dict of server descriptions; copy it. 
- sds = topology_description.server_descriptions() - - # Replace this server's description with the new one. - sds[address] = server_description - - if topology_type == TOPOLOGY_TYPE.Single: - # Set server type to Unknown if replica set name does not match. - if set_name is not None and set_name != server_description.replica_set_name: - error = ConfigurationError( - "client is configured to connect to a replica set named " - "'{}' but this node belongs to a set named '{}'".format( - set_name, server_description.replica_set_name - ) - ) - sds[address] = server_description.to_unknown(error=error) - # Single type never changes. - return TopologyDescription( - TOPOLOGY_TYPE.Single, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - if topology_type == TOPOLOGY_TYPE.Unknown: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): - if len(topology_description._topology_settings.seeds) == 1: - topology_type = TOPOLOGY_TYPE.Single - else: - # Remove standalone from Topology when given multiple seeds. 
- sds.pop(address) - elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): - topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] - - if topology_type == TOPOLOGY_TYPE.Sharded: - if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): - sds.pop(address) - - elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): - sds.pop(address) - - elif server_type == SERVER_TYPE.RSPrimary: - (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( - sds, set_name, server_description, max_set_version, max_election_id - ) - - elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): - topology_type, set_name = _update_rs_no_primary_from_member( - sds, set_name, server_description - ) - - elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): - sds.pop(address) - topology_type = _check_has_primary(sds) - - elif server_type == SERVER_TYPE.RSPrimary: - (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( - sds, set_name, server_description, max_set_version, max_election_id - ) - - elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): - topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description) - - else: - # Server type is Unknown or RSGhost: did we just lose the primary? - topology_type = _check_has_primary(sds) - - # Return updated copy. - return TopologyDescription( - topology_type, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - -def _updated_topology_description_srv_polling( - topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. 
- - :param topology_description: the current TopologyDescription - :param seedlist: a list of new seeds new ServerDescription that resulted from - a hello call - """ - assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES - # Create a copy of the server descriptions. - sds = topology_description.server_descriptions() - - # If seeds haven't changed, don't do anything. - if set(sds.keys()) == set(seedlist): - return topology_description - - # Remove SDs corresponding to servers no longer part of the SRV record. - for address in list(sds.keys()): - if address not in seedlist: - sds.pop(address) - - if topology_description.srv_max_hosts != 0: - new_hosts = set(seedlist) - set(sds.keys()) - n_to_add = topology_description.srv_max_hosts - len(sds) - if n_to_add > 0: - seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts))) - else: - seedlist = [] - # Add SDs corresponding to servers recently added to the SRV record. - for address in seedlist: - if address not in sds: - sds[address] = ServerDescription(address) - return TopologyDescription( - topology_description.topology_type, - sds, - topology_description.replica_set_name, - topology_description.max_set_version, - topology_description.max_election_id, - topology_description._topology_settings, - ) - - -def _update_rs_from_primary( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], -) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: - """Update topology description from a primary's hello response. - - Pass in a dict of ServerDescriptions, current replica set name, the - ServerDescription we are processing, and the TopologyDescription's - max_set_version and max_election_id if any. - - Returns (new topology type, new replica_set_name, new max_set_version, - new max_election_id). 
- """ - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - # We found a primary but it doesn't have the replica_set_name - # provided by the user. - sds.pop(server_description.address) - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - - if server_description.max_wire_version is None or server_description.max_wire_version < 17: - new_election_tuple: tuple = (server_description.set_version, server_description.election_id) - max_election_tuple: tuple = (max_set_version, max_election_id) - if None not in new_election_tuple: - if None not in max_election_tuple and new_election_tuple < max_election_tuple: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - max_election_id = server_description.election_id - - if server_description.set_version is not None and ( - max_set_version is None or server_description.set_version > max_set_version - ): - max_set_version = server_description.set_version - else: - new_election_tuple = server_description.election_id, server_description.set_version - max_election_tuple = max_election_id, max_set_version - new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple) - max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple) - if new_election_safe < max_election_safe: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - else: - max_election_id = server_description.election_id - max_set_version = server_description.set_version - - # We've heard from the primary. Is it the same primary as before? 
- for server in sds.values(): - if ( - server.server_type is SERVER_TYPE.RSPrimary - and server.address != server_description.address - ): - # Reset old primary's type to Unknown. - sds[server.address] = server.to_unknown() - - # There can be only one prior primary. - break - - # Discover new hosts from this primary's response. - for new_address in server_description.all_hosts: - if new_address not in sds: - sds[new_address] = ServerDescription(new_address) - - # Remove hosts not in the response. - for addr in set(sds) - server_description.all_hosts: - sds.pop(addr) - - # If the host list differs from the seed list, we may not have a primary - # after all. - return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) - - -def _update_rs_with_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> int: - """RS with known primary. Process a response from a non-primary. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns new topology type. - """ - assert replica_set_name is not None - - if replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - elif server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - # Had this member been the primary? - return _check_has_primary(sds) - - -def _update_rs_no_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> tuple[int, Optional[str]]: - """RS without known primary. Update from a non-primary's response. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns (new topology type, new replica_set_name). 
- """ - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - return topology_type, replica_set_name - - # This isn't the primary's response, so don't remove any servers - # it doesn't report. Only add new servers. - for address in server_description.all_hosts: - if address not in sds: - sds[address] = ServerDescription(address) - - if server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - return topology_type, replica_set_name - - -def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: - """Current topology type is ReplicaSetWithPrimary. Is primary still known? - - Pass in a dict of ServerDescriptions. - - Returns new topology type. - """ - for s in sds.values(): - if s.server_type == SERVER_TYPE.RSPrimary: - return TOPOLOGY_TYPE.ReplicaSetWithPrimary - else: # noqa: PLW0120 - return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/asynchronous/typings.py b/pymongo/asynchronous/typings.py deleted file mode 100644 index 508c5b6de..000000000 --- a/pymongo/asynchronous/typings.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2022-Present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Type aliases used by PyMongo""" -from __future__ import annotations - -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) - -from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg - -if TYPE_CHECKING: - from pymongo.asynchronous.collation import Collation - -_IS_SYNC = False - -# Common Shared Types. -_Address = Tuple[str, Optional[int]] -_CollationIn = Union[Mapping[str, Any], "Collation"] -_Pipeline = Sequence[Mapping[str, Any]] -ClusterTime = Mapping[str, Any] - -_T = TypeVar("_T") - - -def strip_optional(elem: Optional[_T]) -> _T: - """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T - while inside a list comprehension. - """ - assert elem is not None - return elem - - -__all__ = [ - "_DocumentOut", - "_DocumentType", - "_DocumentTypeArg", - "_Address", - "_CollationIn", - "_Pipeline", - "strip_optional", -] diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py deleted file mode 100644 index b5fde6c30..000000000 --- a/pymongo/asynchronous/uri_parser.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2011-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
- - -"""Tools to parse and validate a MongoDB URI.""" -from __future__ import annotations - -import re -import sys -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sized, - Union, - cast, -) -from urllib.parse import unquote_plus - -from pymongo.asynchronous.client_options import _parse_ssl_options -from pymongo.asynchronous.common import ( - INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, - URI_OPTIONS_DEPRECATION_MAP, - _CaseInsensitiveDictionary, - get_validated_options, -) -from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver -from pymongo.asynchronous.typings import _Address -from pymongo.errors import ConfigurationError, InvalidURI - -if TYPE_CHECKING: - from pymongo.pyopenssl_context import SSLContext - -_IS_SYNC = False -SCHEME = "mongodb://" -SCHEME_LEN = len(SCHEME) -SRV_SCHEME = "mongodb+srv://" -SRV_SCHEME_LEN = len(SRV_SCHEME) -DEFAULT_PORT = 27017 - - -def _unquoted_percent(s: str) -> bool: - """Check for unescaped percent signs. - - :param s: A string. `s` can have things like '%25', '%2525', - and '%E2%85%A8' but cannot have unquoted percent like '%foo'. - """ - for i in range(len(s)): - if s[i] == "%": - sub = s[i : i + 3] - # If unquoting yields the same string this means there was an - # unquoted %. - if unquote_plus(sub) == sub: - return True - return False - - -def parse_userinfo(userinfo: str) -> tuple[str, str]: - """Validates the format of user information in a MongoDB URI. - Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", - "]", "@") as per RFC 3986 must be escaped. - - Returns a 2-tuple containing the unescaped username followed - by the unescaped password. 
- - :param userinfo: A string of the form : - """ - if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): - raise InvalidURI( - "Username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - - user, _, passwd = userinfo.partition(":") - # No password is expected with GSSAPI authentication. - if not user: - raise InvalidURI("The empty string is not valid username.") - - return unquote_plus(user), unquote_plus(passwd) - - -def parse_ipv6_literal_host( - entity: str, default_port: Optional[int] -) -> tuple[str, Optional[Union[str, int]]]: - """Validates an IPv6 literal host:port string. - - Returns a 2-tuple of IPv6 literal followed by port where - port is default_port if it wasn't specified in entity. - - :param entity: A string that represents an IPv6 literal enclosed - in braces (e.g. '[::1]' or '[::1]:27017'). - :param default_port: The port number to use when one wasn't - specified in entity. - """ - if entity.find("]") == -1: - raise ValueError( - "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." - ) - i = entity.find("]:") - if i == -1: - return entity[1:-1], default_port - return entity[1:i], entity[i + 2 :] - - -def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: - """Validates a host string - - Returns a 2-tuple of host followed by port where port is default_port - if it wasn't specified in the string. - - :param entity: A host or host:port string where host could be a - hostname or IP address. - :param default_port: The port number to use when one wasn't - specified in entity. 
- """ - host = entity - port: Optional[Union[str, int]] = default_port - if entity[0] == "[": - host, port = parse_ipv6_literal_host(entity, default_port) - elif entity.endswith(".sock"): - return entity, default_port - elif entity.find(":") != -1: - if entity.count(":") > 1: - raise ValueError( - "Reserved characters such as ':' must be " - "escaped according RFC 2396. An IPv6 " - "address literal must be enclosed in '[' " - "and ']' according to RFC 2732." - ) - host, port = host.split(":", 1) - if isinstance(port, str): - if not port.isdigit() or int(port) > 65535 or int(port) <= 0: - raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") - port = int(port) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -# Options whose values are implicitly determined by tlsInsecure. -_IMPLICIT_TLSINSECURE_OPTS = { - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", -} - - -def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: - """Helper method for split_options which creates the options dict. - Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string. 
- """ - options = _CaseInsensitiveDictionary() - for uriopt in opts.split(delim): - key, value = uriopt.split("=") - if key.lower() == "readpreferencetags": - options.setdefault(key, []).append(value) - else: - if key in options: - warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) - if key.lower() == "authmechanismproperties": - val = value - else: - val = unquote_plus(value) - options[key] = val - - return options - - -def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Raise appropriate errors when conflicting TLS options are present in - the options dictionary. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Implicitly defined options must not be explicitly specified. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - if opt in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) - ) - - # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. - tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") - if tlsallowinvalidcerts is not None: - if "tlsdisableocspendpointcheck" in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg - % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) - ) - if tlsallowinvalidcerts is True: - options["tlsdisableocspendpointcheck"] = True - - # Handle co-occurence of CRL and OCSP-related options. - tlscrlfile = options.get("tlscrlfile") - if tlscrlfile is not None: - for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): - if options.get(opt) is True: - err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." 
- raise InvalidURI(err_msg % (opt,)) - - if "ssl" in options and "tls" in options: - - def truth_value(val: Any) -> Any: - if val in ("true", "false"): - return val == "true" - if isinstance(val, bool): - return val - return val - - if truth_value(options.get("ssl")) != truth_value(options.get("tls")): - err_msg = "Can not specify conflicting values for URI options %s and %s." - raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) - - return options - - -def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Issue appropriate warnings when deprecated options are present in the - options dictionary. Removes deprecated option key, value pairs if the - options dictionary is found to also have the renamed option. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - for optname in list(options): - if optname in URI_OPTIONS_DEPRECATION_MAP: - mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] - if mode == "renamed": - newoptname = message - if newoptname in options: - warn_msg = "Deprecated option '%s' ignored in favor of '%s'." - warnings.warn( - warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), - DeprecationWarning, - stacklevel=2, - ) - options.pop(optname) - continue - warn_msg = "Option '%s' is deprecated, use '%s' instead." - warnings.warn( - warn_msg % (options.cased_key(optname), newoptname), - DeprecationWarning, - stacklevel=2, - ) - elif mode == "removed": - warn_msg = "Option '%s' is deprecated. %s." - warnings.warn( - warn_msg % (options.cased_key(optname), message), - DeprecationWarning, - stacklevel=2, - ) - - return options - - -def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Normalizes option names in the options dictionary by converting them to - their internally-used names. 
- - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Expand the tlsInsecure option. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - # Implicit options are logically the same as tlsInsecure. - options[opt] = tlsinsecure - - for optname in list(options): - intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) - if intname is not None: - options[intname] = options.pop(optname) - - return options - - -def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: - """Validates and normalizes options passed in a MongoDB URI. - - Returns a new dictionary of validated and normalized options. If warn is - False then errors will be thrown for invalid options, otherwise they will - be ignored and a warning will be issued. - - :param opts: A dict of MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. Otherwise invalid options will - cause errors. - """ - return get_validated_options(opts, warn) - - -def split_options( - opts: str, validate: bool = True, warn: bool = False, normalize: bool = True -) -> MutableMapping[str, Any]: - """Takes the options portion of a MongoDB URI, validates each option - and returns the options in a dictionary. - - :param opt: A string representing MongoDB URI options. - :param validate: If ``True`` (the default), validate and normalize all - options. - :param warn: If ``False`` (the default), suppress all warnings raised - during validation of options. - :param normalize: If ``True`` (the default), renames all options to their - internally-used names. 
- """ - and_idx = opts.find("&") - semi_idx = opts.find(";") - try: - if and_idx >= 0 and semi_idx >= 0: - raise InvalidURI("Can not mix '&' and ';' for option separators.") - elif and_idx >= 0: - options = _parse_options(opts, "&") - elif semi_idx >= 0: - options = _parse_options(opts, ";") - elif opts.find("=") != -1: - options = _parse_options(opts, None) - else: - raise ValueError - except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs.") from None - - options = _handle_security_options(options) - - options = _handle_option_deprecations(options) - - if normalize: - options = _normalize_options(options) - - if validate: - options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) - if options.get("authsource") == "": - raise InvalidURI("the authSource database cannot be an empty string") - - return options - - -def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: - """Takes a string of the form host1[:port],host2[:port]... and - splits it into (host, port) tuples. If [:port] isn't present the - default_port is used. - - Returns a set of 2-tuples containing the host name (or IP) followed by - port number. - - :param hosts: A string of the form host1[:port],host2[:port],... - :param default_port: The port number to use when one wasn't specified - for a host. - """ - nodes = [] - for entity in hosts.split(","): - if not entity: - raise ConfigurationError("Empty host (or extra comma in host list).") - port = default_port - # Unix socket entities don't have ports - if entity.endswith(".sock"): - port = None - nodes.append(parse_host(entity, port)) - return nodes - - -# Prohibited characters in database name. DB names also can't have ".", but for -# backward-compat we allow "db.collection" in URI. 
-_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") - -_ALLOWED_TXT_OPTS = frozenset( - ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] -) - - -def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: - # Ensure directConnection was not True if there are multiple seeds. - if len(nodes) > 1 and options.get("directconnection"): - raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") - - if options.get("loadbalanced"): - if len(nodes) > 1: - raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") - if options.get("directconnection"): - raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") - if options.get("replicaset"): - raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") - - -def parse_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': , - 'username': or None, - 'password': or None, - 'database': or None, - 'collection': or None, - 'options': , - 'fqdn': or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. 
- :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. - """ - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." 
in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - - # Use the connection timeout. connectTimeoutMS passed as a keyword - # argument overrides the same option passed in the connection string. 
- connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: - """Parse KMS TLS connection options.""" - if not kms_tls_options: - return {} - if not isinstance(kms_tls_options, dict): - raise TypeError("kms_tls_options must be a dict") - contexts = {} - for provider, options in kms_tls_options.items(): - if not isinstance(options, dict): - raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - options.setdefault("tls", True) - opts = 
_CaseInsensitiveDictionary(options) - opts = _handle_security_options(opts) - opts = _normalize_options(opts) - opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) - ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) - if ssl_context is None: - raise ConfigurationError("TLS is required for KMS providers") - if allow_invalid_hostnames: - raise ConfigurationError("Insecure TLS options prohibited") - - for n in [ - "tlsInsecure", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsDisableCertificateRevocationCheck", - ]: - if n in opts: - raise ConfigurationError(f"Insecure TLS options prohibited: {n}") - contexts[provider] = ssl_context - return contexts - - -if __name__ == "__main__": - import pprint - - try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) diff --git a/pymongo/auth.py b/pymongo/auth.py index 13302ae5d..a65113841 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -15,6 +15,7 @@ """Re-import of synchronous Auth API for compatibility.""" from __future__ import annotations +from pymongo.auth_shared import * # noqa: F403 from pymongo.synchronous.auth import * # noqa: F403 from pymongo.synchronous.auth import __doc__ as original_doc diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index fa7f7f297..4ac266de5 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -15,7 +15,9 @@ """Re-import of synchronous AuthOIDC API for compatibility.""" from __future__ import annotations +from pymongo.auth_oidc_shared import * # noqa: F403 from pymongo.synchronous.auth_oidc import * # noqa: F403 from pymongo.synchronous.auth_oidc import __doc__ as original_doc __doc__ = original_doc +__all__ = ["OIDCCallback", "OIDCCallbackContext", "OIDCCallbackResult", "OIDCIdPInfo"] # noqa: F405 diff --git a/pymongo/auth_oidc_shared.py b/pymongo/auth_oidc_shared.py new file mode 100644 index 000000000..5e3603fa3 --- /dev/null +++ 
b/pymongo/auth_oidc_shared.py @@ -0,0 +1,118 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Constants, types, and classes shared across OIDC auth implementations.""" +from __future__ import annotations + +import abc +import os +from dataclasses import dataclass, field +from typing import Optional +from urllib.parse import quote + +from pymongo._azure_helpers import _get_azure_response +from pymongo._gcp_helpers import _get_gcp_response + + +@dataclass +class OIDCIdPInfo: + issuer: str + clientId: Optional[str] = field(default=None) + requestScopes: Optional[list[str]] = field(default=None) + + +@dataclass +class OIDCCallbackContext: + timeout_seconds: float + username: str + version: int + refresh_token: Optional[str] = field(default=None) + idp_info: Optional[OIDCIdPInfo] = field(default=None) + + +@dataclass +class OIDCCallbackResult: + access_token: str + expires_in_seconds: Optional[float] = field(default=None) + refresh_token: Optional[str] = field(default=None) + + +class OIDCCallback(abc.ABC): + """A base class for defining OIDC callbacks.""" + + @abc.abstractmethod + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + """Convert the given BSON value into our own type.""" + + +@dataclass +class _OIDCProperties: + callback: Optional[OIDCCallback] = field(default=None) + human_callback: Optional[OIDCCallback] = field(default=None) + environment: Optional[str] = field(default=None) + 
allowed_hosts: list[str] = field(default_factory=list) + token_resource: Optional[str] = field(default=None) + username: str = "" + + +"""Mechanism properties for MONGODB-OIDC authentication.""" + +TOKEN_BUFFER_MINUTES = 5 +HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60 +CALLBACK_VERSION = 1 +MACHINE_CALLBACK_TIMEOUT_SECONDS = 60 +TIME_BETWEEN_CALLS_SECONDS = 0.1 + + +class _OIDCTestCallback(OIDCCallback): + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + token_file = os.environ.get("OIDC_TOKEN_FILE") + if not token_file: + raise RuntimeError( + 'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set' + ) + with open(token_file) as fid: + return OIDCCallbackResult(access_token=fid.read().strip()) + + +class _OIDCAWSCallback(OIDCCallback): + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") + if not token_file: + raise RuntimeError( + 'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set' + ) + with open(token_file) as fid: + return OIDCCallbackResult(access_token=fid.read().strip()) + + +class _OIDCAzureCallback(OIDCCallback): + def __init__(self, token_resource: str) -> None: + self.token_resource = quote(token_resource) + + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds) + return OIDCCallbackResult( + access_token=resp["access_token"], expires_in_seconds=resp["expires_in"] + ) + + +class _OIDCGCPCallback(OIDCCallback): + def __init__(self, token_resource: str) -> None: + self.token_resource = quote(token_resource) + + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + resp = _get_gcp_response(self.token_resource, context.timeout_seconds) + return OIDCCallbackResult(access_token=resp["access_token"]) diff --git a/pymongo/auth_shared.py b/pymongo/auth_shared.py new file mode 100644 index 
000000000..7e3acd9df --- /dev/null +++ b/pymongo/auth_shared.py @@ -0,0 +1,236 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Constants and types shared across multiple auth types.""" +from __future__ import annotations + +import os +import typing +from base64 import standard_b64encode +from collections import namedtuple +from typing import Any, Dict, Mapping, Optional + +from bson import Binary +from pymongo.auth_oidc_shared import ( + _OIDCAzureCallback, + _OIDCGCPCallback, + _OIDCProperties, + _OIDCTestCallback, +) +from pymongo.errors import ConfigurationError + +MECHANISMS = frozenset( + [ + "GSSAPI", + "MONGODB-CR", + "MONGODB-OIDC", + "MONGODB-X509", + "MONGODB-AWS", + "PLAIN", + "SCRAM-SHA-1", + "SCRAM-SHA-256", + "DEFAULT", + ] +) +"""The authentication mechanisms supported by PyMongo.""" + + +class _Cache: + __slots__ = ("data",) + + _hash_val = hash("_Cache") + + def __init__(self) -> None: + self.data = None + + def __eq__(self, other: object) -> bool: + # Two instances must always compare equal. 
+ if isinstance(other, _Cache): + return True + return NotImplemented + + def __ne__(self, other: object) -> bool: + if isinstance(other, _Cache): + return False + return NotImplemented + + def __hash__(self) -> int: + return self._hash_val + + +MongoCredential = namedtuple( + "MongoCredential", + ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], +) +"""A hashable namedtuple of values used for authentication.""" + + +GSSAPIProperties = namedtuple( + "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] +) +"""Mechanism properties for GSSAPI authentication.""" + + +_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) +"""Mechanism properties for MONGODB-AWS authentication.""" + + +def _build_credentials_tuple( + mech: str, + source: Optional[str], + user: str, + passwd: str, + extra: Mapping[str, Any], + database: Optional[str], +) -> MongoCredential: + """Build and return a mechanism specific credentials tuple.""" + if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: + raise ConfigurationError(f"{mech} requires a username.") + if mech == "GSSAPI": + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for GSSAPI") + properties = extra.get("authmechanismproperties", {}) + service_name = properties.get("SERVICE_NAME", "mongodb") + canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False)) + service_realm = properties.get("SERVICE_REALM") + props = GSSAPIProperties( + service_name=service_name, + canonicalize_host_name=canonicalize, + service_realm=service_realm, + ) + # Source is always $external. 
+ return MongoCredential(mech, "$external", user, passwd, props, None) + elif mech == "MONGODB-X509": + if passwd is not None: + raise ConfigurationError("Passwords are not supported by MONGODB-X509") + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for MONGODB-X509") + # Source is always $external, user can be None. + return MongoCredential(mech, "$external", user, None, None, None) + elif mech == "MONGODB-AWS": + if user is not None and passwd is None: + raise ConfigurationError("username without a password is not supported by MONGODB-AWS") + if source is not None and source != "$external": + raise ConfigurationError( + "authentication source must be $external or None for MONGODB-AWS" + ) + + properties = extra.get("authmechanismproperties", {}) + aws_session_token = properties.get("AWS_SESSION_TOKEN") + aws_props = _AWSProperties(aws_session_token=aws_session_token) + # user can be None for temporary link-local EC2 credentials. 
+ return MongoCredential(mech, "$external", user, passwd, aws_props, None) + elif mech == "MONGODB-OIDC": + properties = extra.get("authmechanismproperties", {}) + callback = properties.get("OIDC_CALLBACK") + human_callback = properties.get("OIDC_HUMAN_CALLBACK") + environ = properties.get("ENVIRONMENT") + token_resource = properties.get("TOKEN_RESOURCE", "") + default_allowed = [ + "*.mongodb.net", + "*.mongodb-dev.net", + "*.mongodb-qa.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", + ] + allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) + msg = ( + "authentication with MONGODB-OIDC requires providing either a callback or a environment" + ) + if passwd is not None: + msg = "password is not supported by MONGODB-OIDC" + raise ConfigurationError(msg) + if callback or human_callback: + if environ is not None: + raise ConfigurationError(msg) + if callback and human_callback: + msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" + raise ConfigurationError(msg) + elif environ is not None: + if environ == "test": + if user is not None: + msg = "test environment for MONGODB-OIDC does not support username" + raise ConfigurationError(msg) + callback = _OIDCTestCallback() + elif environ == "azure": + passwd = None + if not token_resource: + raise ConfigurationError( + "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" + ) + callback = _OIDCAzureCallback(token_resource) + elif environ == "gcp": + passwd = None + if not token_resource: + raise ConfigurationError( + "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" + ) + callback = _OIDCGCPCallback(token_resource) + else: + raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}") + else: + raise ConfigurationError(msg) + + oidc_props = _OIDCProperties( + callback=callback, + human_callback=human_callback, + environment=environ, + allowed_hosts=allowed_hosts, + token_resource=token_resource, + 
username=user, + ) + return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) + + elif mech == "PLAIN": + source_database = source or database or "$external" + return MongoCredential(mech, source_database, user, passwd, None, None) + else: + source_database = source or database or "admin" + if passwd is None: + raise ConfigurationError("A password is required.") + return MongoCredential(mech, source_database, user, passwd, None, _Cache()) + + +def _xor(fir: bytes, sec: bytes) -> bytes: + """XOR two byte strings together.""" + return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) + + +def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]: + """Split a scram response into key, value pairs.""" + return dict( + typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1)) + for item in response.split(b",") + ) + + +def _authenticate_scram_start( + credentials: MongoCredential, mechanism: str +) -> tuple[bytes, bytes, typing.MutableMapping[str, Any]]: + username = credentials.username + user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") + nonce = standard_b64encode(os.urandom(32)) + first_bare = b"n=" + user + b",r=" + nonce + + cmd = { + "saslStart": 1, + "mechanism": mechanism, + "payload": Binary(b"n,," + first_bare), + "autoAuthorize": 1, + "options": {"skipEmptyExchange": True}, + } + return nonce, first_bare, cmd diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py new file mode 100644 index 000000000..7aa6340d5 --- /dev/null +++ b/pymongo/bulk_shared.py @@ -0,0 +1,131 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Constants, types, and classes shared across Bulk Write API implementations.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, NoReturn + +from pymongo.errors import BulkWriteError, OperationFailure +from pymongo.helpers_shared import _get_wce_doc +from pymongo.message import ( + _DELETE, + _INSERT, + _UPDATE, +) + +if TYPE_CHECKING: + from pymongo.typings import _DocumentOut + + +_DELETE_ALL: int = 0 +_DELETE_ONE: int = 1 + +# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err +_BAD_VALUE: int = 2 +_UNKNOWN_ERROR: int = 8 +_WRITE_CONCERN_ERROR: int = 64 + +_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") + + +class _Run: + """Represents a batch of write operations.""" + + def __init__(self, op_type: int) -> None: + """Initialize a new Run object.""" + self.op_type: int = op_type + self.index_map: list[int] = [] + self.ops: list[Any] = [] + self.idx_offset: int = 0 + + def index(self, idx: int) -> int: + """Get the original index of an operation in this run. + + :param idx: The Run index that maps to the original index. + """ + return self.index_map[idx] + + def add(self, original_index: int, operation: Any) -> None: + """Add an operation to this Run instance. + + :param original_index: The original index of this operation + within a larger bulk operation. + :param operation: The operation document. 
+ """ + self.index_map.append(original_index) + self.ops.append(operation) + + +def _merge_command( + run: _Run, + full_result: MutableMapping[str, Any], + offset: int, + result: Mapping[str, Any], +) -> None: + """Merge a write command result into the full bulk result.""" + affected = result.get("n", 0) + + if run.op_type == _INSERT: + full_result["nInserted"] += affected + + elif run.op_type == _DELETE: + full_result["nRemoved"] += affected + + elif run.op_type == _UPDATE: + upserted = result.get("upserted") + if upserted: + n_upserted = len(upserted) + for doc in upserted: + doc["index"] = run.index(doc["index"] + offset) + full_result["upserted"].extend(upserted) + full_result["nUpserted"] += n_upserted + full_result["nMatched"] += affected - n_upserted + else: + full_result["nMatched"] += affected + full_result["nModified"] += result["nModified"] + + write_errors = result.get("writeErrors") + if write_errors: + for doc in write_errors: + # Leave the server response intact for APM. + replacement = doc.copy() + idx = doc["index"] + offset + replacement["index"] = run.index(idx) + # Add the failed operation to the error document. + replacement["op"] = run.ops[idx] + full_result["writeErrors"].append(replacement) + + wce = _get_wce_doc(result) + if wce: + full_result["writeConcernErrors"].append(wce) + + +def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: + """Raise a BulkWriteError from the full bulk api result.""" + # retryWrites on MMAPv1 should raise an actionable error. + if full_result["writeErrors"]: + full_result["writeErrors"].sort(key=lambda error: error["index"]) + err = full_result["writeErrors"][0] + code = err["code"] + msg = err["errmsg"] + if code == 20 and msg.startswith("Transaction numbers"): + errmsg = ( + "This MongoDB deployment does not support " + "retryable writes. Please add retryWrites=false " + "to your connection string." 
+ ) + raise OperationFailure(errmsg, code, full_result) + raise BulkWriteError(full_result) diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 5decc0991..b96a1750c 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -19,3 +19,4 @@ from pymongo.synchronous.change_stream import * # noqa: F403 from pymongo.synchronous.change_stream import __doc__ as original_doc __doc__ = original_doc +__all__ = ["ChangeStream", "ClusterChangeStream", "CollectionChangeStream", "DatabaseChangeStream"] # noqa: F405 diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 7a4e04453..2fb7b30c7 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -1,21 +1,332 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2014-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
-"""Re-import of synchronous ClientOptions API for compatibility.""" +"""Tools to parse mongo client options.""" from __future__ import annotations -from pymongo.synchronous.client_options import * # noqa: F403 -from pymongo.synchronous.client_options import __doc__ as original_doc +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast -__doc__ = original_doc +from bson.codec_options import _parse_codec_options +from pymongo import common +from pymongo.compression_support import CompressionSettings +from pymongo.errors import ConfigurationError +from pymongo.monitoring import _EventListener, _EventListeners +from pymongo.pool_options import PoolOptions +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ( + _ServerMode, + make_read_preference, + read_pref_mode_from_name, +) +from pymongo.server_selectors import any_server_selector +from pymongo.ssl_support import get_ssl_context +from pymongo.write_concern import WriteConcern, validate_boolean + +if TYPE_CHECKING: + from bson.codec_options import CodecOptions + from pymongo.auth_shared import MongoCredential + from pymongo.encryption_options import AutoEncryptionOpts + from pymongo.pyopenssl_context import SSLContext + from pymongo.topology_description import _ServerSelector + + +def _parse_credentials( + username: str, password: str, database: Optional[str], options: Mapping[str, Any] +) -> Optional[MongoCredential]: + """Parse authentication credentials.""" + mechanism = options.get("authmechanism", "DEFAULT" if username else None) + source = options.get("authsource") + if username or mechanism: + from pymongo.auth_shared import _build_credentials_tuple + + return _build_credentials_tuple(mechanism, source, username, password, options, database) + return None + + +def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: + """Parse read preference options.""" + if "read_preference" in options: + return options["read_preference"] + + name = 
options.get("readpreference", "primary") + mode = read_pref_mode_from_name(name) + tags = options.get("readpreferencetags") + max_staleness = options.get("maxstalenessseconds", -1) + return make_read_preference(mode, tags, max_staleness) + + +def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: + """Parse write concern options.""" + concern = options.get("w") + wtimeout = options.get("wtimeoutms") + j = options.get("journal") + fsync = options.get("fsync") + return WriteConcern(concern, wtimeout, j, fsync) + + +def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: + """Parse read concern options.""" + concern = options.get("readconcernlevel") + return ReadConcern(concern) + + +def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: + """Parse ssl options.""" + use_tls = options.get("tls") + if use_tls is not None: + validate_boolean("tls", use_tls) + + certfile = options.get("tlscertificatekeyfile") + passphrase = options.get("tlscertificatekeyfilepassword") + ca_certs = options.get("tlscafile") + crlfile = options.get("tlscrlfile") + allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) + allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) + disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) + + enabled_tls_opts = [] + for opt in ( + "tlscertificatekeyfile", + "tlscertificatekeyfilepassword", + "tlscafile", + "tlscrlfile", + ): + # Any non-null value of these options implies tls=True. + if opt in options and options[opt]: + enabled_tls_opts.append(opt) + for opt in ( + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", + ): + # A value of False for these options implies tls=True. + if opt in options and not options[opt]: + enabled_tls_opts.append(opt) + + if enabled_tls_opts: + if use_tls is None: + # Implicitly enable TLS when one of the tls* options is set. 
+ use_tls = True + elif not use_tls: + # Error since tls is explicitly disabled but a tls option is set. + raise ConfigurationError( + "TLS has not been enabled but the " + "following tls parameters have been set: " + "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) + ) + + if use_tls: + ctx = get_ssl_context( + certfile, + passphrase, + ca_certs, + crlfile, + allow_invalid_certificates, + allow_invalid_hostnames, + disable_ocsp_endpoint_check, + ) + return ctx, allow_invalid_hostnames + return None, allow_invalid_hostnames + + +def _parse_pool_options( + username: str, password: str, database: Optional[str], options: Mapping[str, Any] +) -> PoolOptions: + """Parse connection pool options.""" + credentials = _parse_credentials(username, password, database, options) + max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) + min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) + max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) + if max_pool_size is not None and min_pool_size > max_pool_size: + raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") + connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) + socket_timeout = options.get("sockettimeoutms") + wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) + event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) + appname = options.get("appname") + driver = options.get("driver") + server_api = options.get("server_api") + compression_settings = CompressionSettings( + options.get("compressors", []), options.get("zlibcompressionlevel", -1) + ) + ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) + load_balanced = options.get("loadbalanced") + max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) + return PoolOptions( + max_pool_size, + min_pool_size, + max_idle_time_seconds, + connect_timeout, + socket_timeout, + 
wait_queue_timeout, + ssl_context, + tls_allow_invalid_hostnames, + _EventListeners(event_listeners), + appname, + driver, + compression_settings, + max_connecting=max_connecting, + server_api=server_api, + load_balanced=load_balanced, + credentials=credentials, + ) + + +class ClientOptions: + """Read only configuration options for an AsyncMongoClient/MongoClient. + + Should not be instantiated directly by application developers. Access + a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options` or :attr:`pymongo.mongo_client.MongoClient.options` + instead. + """ + + def __init__( + self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] + ): + self.__options = options + self.__codec_options = _parse_codec_options(options) + self.__direct_connection = options.get("directconnection") + self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) + # self.__server_selection_timeout is in seconds. Must use full name for + # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
+ self.__server_selection_timeout = options.get( + "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT + ) + self.__pool_options = _parse_pool_options(username, password, database, options) + self.__read_preference = _parse_read_preference(options) + self.__replica_set_name = options.get("replicaset") + self.__write_concern = _parse_write_concern(options) + self.__read_concern = _parse_read_concern(options) + self.__connect = options.get("connect") + self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) + self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) + self.__retry_reads = options.get("retryreads", common.RETRY_READS) + self.__server_selector = options.get("server_selector", any_server_selector) + self.__auto_encryption_opts = options.get("auto_encryption_opts") + self.__load_balanced = options.get("loadbalanced") + self.__timeout = options.get("timeoutms") + self.__server_monitoring_mode = options.get( + "servermonitoringmode", common.SERVER_MONITORING_MODE + ) + + @property + def _options(self) -> Mapping[str, Any]: + """The original options used to create this ClientOptions.""" + return self.__options + + @property + def connect(self) -> Optional[bool]: + """Whether to begin discovering a MongoDB topology automatically.""" + return self.__connect + + @property + def codec_options(self) -> CodecOptions: + """A :class:`~bson.codec_options.CodecOptions` instance.""" + return self.__codec_options + + @property + def direct_connection(self) -> Optional[bool]: + """Whether to connect to the deployment in 'Single' topology.""" + return self.__direct_connection + + @property + def local_threshold_ms(self) -> int: + """The local threshold for this instance.""" + return self.__local_threshold_ms + + @property + def server_selection_timeout(self) -> int: + """The server selection timeout for this instance in seconds.""" + return self.__server_selection_timeout + + @property + def server_selector(self) 
-> _ServerSelector: + return self.__server_selector + + @property + def heartbeat_frequency(self) -> int: + """The monitoring frequency in seconds.""" + return self.__heartbeat_frequency + + @property + def pool_options(self) -> PoolOptions: + """A :class:`~pymongo.pool.PoolOptions` instance.""" + return self.__pool_options + + @property + def read_preference(self) -> _ServerMode: + """A read preference instance.""" + return self.__read_preference + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or None.""" + return self.__replica_set_name + + @property + def write_concern(self) -> WriteConcern: + """A :class:`~pymongo.write_concern.WriteConcern` instance.""" + return self.__write_concern + + @property + def read_concern(self) -> ReadConcern: + """A :class:`~pymongo.read_concern.ReadConcern` instance.""" + return self.__read_concern + + @property + def timeout(self) -> Optional[float]: + """The configured timeoutMS converted to seconds, or None. + + .. versionadded:: 4.2 + """ + return self.__timeout + + @property + def retry_writes(self) -> bool: + """If this instance should retry supported write operations.""" + return self.__retry_writes + + @property + def retry_reads(self) -> bool: + """If this instance should retry supported read operations.""" + return self.__retry_reads + + @property + def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: + """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" + return self.__auto_encryption_opts + + @property + def load_balanced(self) -> Optional[bool]: + """True if the client was configured to connect to a load balancer.""" + return self.__load_balanced + + @property + def event_listeners(self) -> list[_EventListeners]: + """The event listeners registered for this client. + + See :mod:`~pymongo.monitoring` for details. + + .. 
versionadded:: 4.0 + """ + assert self.__pool_options._event_listeners is not None + return self.__pool_options._event_listeners.event_listeners() + + @property + def server_monitoring_mode(self) -> str: + """The configured serverMonitoringMode option. + + .. versionadded:: 4.5 + """ + return self.__server_monitoring_mode diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 0597e8986..1a3af44e1 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -19,3 +19,4 @@ from pymongo.synchronous.client_session import * # noqa: F403 from pymongo.synchronous.client_session import __doc__ as original_doc __doc__ = original_doc +__all__ = ["ClientSession", "SessionOptions", "TransactionOptions"] # noqa: F405 diff --git a/pymongo/collation.py b/pymongo/collation.py index b129a0451..995687296 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,213 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous Collation API for compatibility.""" +"""Tools for working with `collations`_. + +.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ +""" from __future__ import annotations -from pymongo.synchronous.collation import * # noqa: F403 -from pymongo.synchronous.collation import __doc__ as original_doc +from typing import Any, Mapping, Optional, Union -__doc__ = original_doc +from pymongo import common +from pymongo.write_concern import validate_boolean + + +class CollationStrength: + """ + An enum that defines values for `strength` on a + :class:`~pymongo.collation.Collation`. 
+ """ + + PRIMARY = 1 + """Differentiate base (unadorned) characters.""" + + SECONDARY = 2 + """Differentiate character accents.""" + + TERTIARY = 3 + """Differentiate character case.""" + + QUATERNARY = 4 + """Differentiate words with and without punctuation.""" + + IDENTICAL = 5 + """Differentiate unicode code point (characters are exactly identical).""" + + +class CollationAlternate: + """ + An enum that defines values for `alternate` on a + :class:`~pymongo.collation.Collation`. + """ + + NON_IGNORABLE = "non-ignorable" + """Spaces and punctuation are treated as base characters.""" + + SHIFTED = "shifted" + """Spaces and punctuation are *not* considered base characters. + + Spaces and punctuation are distinguished regardless when the + :class:`~pymongo.collation.Collation` strength is at least + :data:`~pymongo.collation.CollationStrength.QUATERNARY`. + + """ + + +class CollationMaxVariable: + """ + An enum that defines values for `max_variable` on a + :class:`~pymongo.collation.Collation`. + """ + + PUNCT = "punct" + """Both punctuation and spaces are ignored.""" + + SPACE = "space" + """Spaces alone are ignored.""" + + +class CollationCaseFirst: + """ + An enum that defines values for `case_first` on a + :class:`~pymongo.collation.Collation`. + """ + + UPPER = "upper" + """Sort uppercase characters first.""" + + LOWER = "lower" + """Sort lowercase characters first.""" + + OFF = "off" + """Default for locale or collation strength.""" + + +class Collation: + """Collation + + :param locale: (string) The locale of the collation. This should be a string + that identifies an `ICU locale ID` exactly. For example, ``en_US`` is + valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB + documentation for a list of supported locales. + :param caseLevel: (optional) If ``True``, turn on case sensitivity if + `strength` is 1 or 2 (case sensitivity is implied if `strength` is + greater than 2). Defaults to ``False``. 
+ :param caseFirst: (optional) Specify that either uppercase or lowercase + characters take precedence. Must be one of the following values: + + * :data:`~CollationCaseFirst.UPPER` + * :data:`~CollationCaseFirst.LOWER` + * :data:`~CollationCaseFirst.OFF` (the default) + + :param strength: Specify the comparison strength. This is also + known as the ICU comparison level. This must be one of the following + values: + + * :data:`~CollationStrength.PRIMARY` + * :data:`~CollationStrength.SECONDARY` + * :data:`~CollationStrength.TERTIARY` (the default) + * :data:`~CollationStrength.QUATERNARY` + * :data:`~CollationStrength.IDENTICAL` + + Each successive level builds upon the previous. For example, a + `strength` of :data:`~CollationStrength.SECONDARY` differentiates + characters based both on the unadorned base character and its accents. + + :param numericOrdering: If ``True``, order numbers numerically + instead of in collation order (defaults to ``False``). + :param alternate: Specify whether spaces and punctuation are + considered base characters. This must be one of the following values: + + * :data:`~CollationAlternate.NON_IGNORABLE` (the default) + * :data:`~CollationAlternate.SHIFTED` + + :param maxVariable: When `alternate` is + :data:`~CollationAlternate.SHIFTED`, this option specifies what + characters may be ignored. This must be one of the following values: + + * :data:`~CollationMaxVariable.PUNCT` (the default) + * :data:`~CollationMaxVariable.SPACE` + + :param normalization: If ``True``, normalizes text into Unicode + NFD. Defaults to ``False``. + :param backwards: If ``True``, accents on characters are + considered from the back of the word to the front, as it is done in some + French dictionary ordering traditions. Defaults to ``False``. + :param kwargs: Keyword arguments supplying any additional options + to be sent with this Collation object. + + .. 
versionadded: 3.4 + + """ + + __slots__ = ("__document",) + + def __init__( + self, + locale: str, + caseLevel: Optional[bool] = None, + caseFirst: Optional[str] = None, + strength: Optional[int] = None, + numericOrdering: Optional[bool] = None, + alternate: Optional[str] = None, + maxVariable: Optional[str] = None, + normalization: Optional[bool] = None, + backwards: Optional[bool] = None, + **kwargs: Any, + ) -> None: + locale = common.validate_string("locale", locale) + self.__document: dict[str, Any] = {"locale": locale} + if caseLevel is not None: + self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) + if caseFirst is not None: + self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) + if strength is not None: + self.__document["strength"] = common.validate_integer("strength", strength) + if numericOrdering is not None: + self.__document["numericOrdering"] = validate_boolean( + "numericOrdering", numericOrdering + ) + if alternate is not None: + self.__document["alternate"] = common.validate_string("alternate", alternate) + if maxVariable is not None: + self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) + if normalization is not None: + self.__document["normalization"] = validate_boolean("normalization", normalization) + if backwards is not None: + self.__document["backwards"] = validate_boolean("backwards", backwards) + self.__document.update(kwargs) + + @property + def document(self) -> dict[str, Any]: + """The document representation of this collation. + + .. note:: + :class:`Collation` is immutable. Mutating the value of + :attr:`document` does not mutate this :class:`Collation`. 
+ """ + return self.__document.copy() + + def __repr__(self) -> str: + document = self.document + return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Collation): + return self.document == other.document + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +def validate_collation_or_none( + value: Optional[Union[Mapping[str, Any], Collation]] +) -> Optional[dict[str, Any]]: + if value is None: + return None + if isinstance(value, Collation): + return value.document + if isinstance(value, dict): + return value + raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/pymongo/collection.py b/pymongo/collection.py index c7427f9b6..f726ed037 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -19,3 +19,7 @@ from pymongo.synchronous.collection import * # noqa: F403 from pymongo.synchronous.collection import __doc__ as original_doc __doc__ = original_doc +__all__ = [ # noqa: F405 + "Collection", + "ReturnDocument", +] diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index d9ca3ee40..941e3a0ed 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -19,3 +19,4 @@ from pymongo.synchronous.command_cursor import * # noqa: F403 from pymongo.synchronous.command_cursor import __doc__ as original_doc __doc__ = original_doc +__all__ = ["CommandCursor", "RawBatchCommandCursor"] # noqa: F405 diff --git a/pymongo/synchronous/common.py b/pymongo/common.py similarity index 98% rename from pymongo/synchronous/common.py rename to pymongo/common.py index 13e58aded..a073eba57 100644 --- a/pymongo/synchronous/common.py +++ b/pymongo/common.py @@ -40,22 +40,21 @@ from bson import SON from bson.binary import UuidRepresentation from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry from bson.raw_bson import RawBSONDocument -from 
pymongo.driver_info import DriverInfo -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.server_api import ServerApi -from pymongo.synchronous.compression_support import ( +from pymongo.compression_support import ( validate_compressors, validate_zlib_compression_level, ) -from pymongo.synchronous.monitoring import _validate_event_listeners -from pymongo.synchronous.read_preferences import _MONGOS_MODES, _ServerMode +from pymongo.driver_info import DriverInfo +from pymongo.errors import ConfigurationError +from pymongo.monitoring import _validate_event_listeners +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import _MONGOS_MODES, _ServerMode +from pymongo.server_api import ServerApi from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean if TYPE_CHECKING: - from pymongo.synchronous.client_session import ClientSession + from pymongo.typings import _AgnosticClientSession -_IS_SYNC = True ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict) @@ -69,7 +68,8 @@ MAX_WRITE_BATCH_SIZE = 1000 # What this version of PyMongo supports. MIN_SUPPORTED_SERVER_VERSION = "3.6" MIN_SUPPORTED_WIRE_VERSION = 6 -MAX_SUPPORTED_WIRE_VERSION = 21 +# MongoDB 8.0 +MAX_SUPPORTED_WIRE_VERSION = 25 # Frequency to call hello on servers, in seconds. 
HEARTBEAT_FREQUENCY = 10 @@ -380,7 +380,7 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: def validate_auth_mechanism(option: str, value: Any) -> str: """Validate the authMechanism URI option.""" - from pymongo.synchronous.auth import MECHANISMS + from pymongo.auth_shared import MECHANISMS if value not in MECHANISMS: raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") @@ -446,7 +446,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): props[key] = value elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: - from pymongo.synchronous.auth_oidc import OIDCCallback + from pymongo.auth_oidc_shared import OIDCCallback if not isinstance(value, OIDCCallback): raise ValueError("callback must be an OIDCCallback object") @@ -642,7 +642,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A """Validate the driver keyword arg.""" if value is None: return value - from pymongo.synchronous.encryption_options import AutoEncryptionOpts + from pymongo.encryption_options import AutoEncryptionOpts if not isinstance(value, AutoEncryptionOpts): raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") @@ -941,7 +941,7 @@ class BaseObject: """ return self._write_concern - def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern: + def _write_concern_for(self, session: Optional[_AgnosticClientSession]) -> WriteConcern: """Read only access to the write concern of this instance or session.""" # Override this operation's write concern with the transaction's. 
if session and session.in_transaction: @@ -957,7 +957,7 @@ class BaseObject: """ return self._read_preference - def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode: + def _read_preference_for(self, session: Optional[_AgnosticClientSession]) -> _ServerMode: """Read only access to the read preference of this instance or session.""" # Override this operation's read preference with the transaction's. if session: diff --git a/pymongo/synchronous/compression_support.py b/pymongo/compression_support.py similarity index 97% rename from pymongo/synchronous/compression_support.py rename to pymongo/compression_support.py index e5153f8c8..7123b90df 100644 --- a/pymongo/synchronous/compression_support.py +++ b/pymongo/compression_support.py @@ -16,11 +16,8 @@ from __future__ import annotations import warnings from typing import Any, Iterable, Optional, Union -from pymongo.helpers_constants import _SENSITIVE_COMMANDS -from pymongo.synchronous.hello_compat import HelloCompat - -_IS_SYNC = True - +from pymongo.hello import HelloCompat +from pymongo.helpers_shared import _SENSITIVE_COMMANDS _SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} diff --git a/pymongo/cursor.py b/pymongo/cursor.py index b3ac54c97..869adddc3 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -20,3 +20,4 @@ from pymongo.synchronous.cursor import * # noqa: F403 from pymongo.synchronous.cursor import __doc__ as original_doc __doc__ = original_doc +__all__ = ["Cursor", "CursorType", "RawBatchCursor"] # noqa: F405 diff --git a/pymongo/database.py b/pymongo/database.py index 6c81ac227..bbd05702d 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -19,3 +19,4 @@ from pymongo.synchronous.database import * # noqa: F403 from pymongo.synchronous.database import __doc__ as original_doc __doc__ = original_doc +__all__ = ["Database"] # noqa: F405 diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 
4887a3f90..5bc2a7590 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -19,3 +19,4 @@ from pymongo.synchronous.encryption import * # noqa: F403 from pymongo.synchronous.encryption import __doc__ as original_doc __doc__ = original_doc +__all__ = ["Algorithm", "ClientEncryption", "QueryType", "RewrapManyDataKeyResult"] # noqa: F405 diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index 350344a6d..b399cb0d4 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,261 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous EncryptionOptions API for compatibility.""" +"""Support for automatic client-side field level encryption.""" from __future__ import annotations -from pymongo.synchronous.encryption_options import * # noqa: F403 -from pymongo.synchronous.encryption_options import __doc__ as original_doc +from typing import TYPE_CHECKING, Any, Mapping, Optional -__doc__ = original_doc +try: + import pymongocrypt # type:ignore[import] # noqa: F401 + + # Check for pymongocrypt>=1.10. 
+ from pymongocrypt import synchronous as _ # noqa: F401 + + _HAVE_PYMONGOCRYPT = True +except ImportError: + _HAVE_PYMONGOCRYPT = False +from bson import int64 +from pymongo.common import validate_is_mapping +from pymongo.errors import ConfigurationError +from pymongo.uri_parser import _parse_kms_tls_options + +if TYPE_CHECKING: + from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg + + +class AutoEncryptionOpts: + """Options to configure automatic client-side field level encryption.""" + + def __init__( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None, + schema_map: Optional[Mapping[str, Any]] = None, + bypass_auto_encryption: bool = False, + mongocryptd_uri: str = "mongodb://localhost:27020", + mongocryptd_bypass_spawn: bool = False, + mongocryptd_spawn_path: str = "mongocryptd", + mongocryptd_spawn_args: Optional[list[str]] = None, + kms_tls_options: Optional[Mapping[str, Any]] = None, + crypt_shared_lib_path: Optional[str] = None, + crypt_shared_lib_required: bool = False, + bypass_query_analysis: bool = False, + encrypted_fields_map: Optional[Mapping[str, Any]] = None, + ) -> None: + """Options to configure automatic client-side field level encryption. + + Automatic client-side field level encryption requires MongoDB >=4.2 + enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not + supported for operations on a database or view and will result in + error. + + Although automatic encryption requires MongoDB >=4.2 enterprise or a + MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all + users. To configure automatic *decryption* without automatic + *encryption* set ``bypass_auto_encryption=True``. Explicit + encryption and explicit decryption is also supported for all users + with the :class:`~pymongo.encryption.ClientEncryption` class. + + See :ref:`automatic-client-side-encryption` for an example. 
+ + :param kms_providers: Map of KMS provider options. The `kms_providers` + map values differ by provider: + + - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. + These are the AWS access key ID and AWS secret access key used + to generate KMS messages. An optional "sessionToken" may be + included to support temporary AWS credentials. + - `azure`: Map with "tenantId", "clientId", and "clientSecret" as + strings. Additionally, "identityPlatformEndpoint" may also be + specified as a string (defaults to 'login.microsoftonline.com'). + These are the Azure Active Directory credentials used to + generate Azure Key Vault messages. + - `gcp`: Map with "email" as a string and "privateKey" + as `bytes` or a base64 encoded string. + Additionally, "endpoint" may also be specified as a string + (defaults to 'oauth2.googleapis.com'). These are the + credentials used to generate Google Cloud KMS messages. + - `kmip`: Map with "endpoint" as a host with required port. + For example: ``{"endpoint": "example.com:443"}``. + - `local`: Map with "key" as `bytes` (96 bytes in length) or + a base64 encoded string which decodes + to 96 bytes. "key" is the master key used to encrypt/decrypt + data keys. This key should be generated and stored as securely + as possible. + + KMS providers may be specified with an optional name suffix + separated by a colon, for example "kmip:name" or "aws:name". + Named KMS providers do not support :ref:`CSFLE on-demand credentials`. + Named KMS providers enables more than one of each KMS provider type to be configured. + For example, to configure multiple local KMS providers:: + + kms_providers = { + "local": {"key": local_kek1}, # Unnamed KMS provider. + "local:myname": {"key": local_kek2}, # Named KMS provider with name "myname". + } + + :param key_vault_namespace: The namespace for the key vault collection. + The key vault collection contains all data keys used for encryption + and decryption. 
Data keys are stored as documents in this MongoDB + collection. Data keys are protected with encryption by a KMS + provider. + :param key_vault_client: By default, the key vault collection + is assumed to reside in the same MongoDB cluster as the encrypted + AsyncMongoClient/MongoClient. Use this option to route data key queries to a + separate MongoDB cluster. + :param schema_map: Map of collection namespace ("db.coll") to + JSON Schema. By default, a collection's JSONSchema is periodically + polled with the listCollections command. But a JSONSchema may be + specified locally with the schemaMap option. + + **Supplying a `schema_map` provides more security than relying on + JSON Schemas obtained from the server. It protects against a + malicious server advertising a false JSON Schema, which could trick + the client into sending unencrypted data that should be + encrypted.** + + Schemas supplied in the schemaMap only apply to configuring + automatic encryption for client side encryption. Other validation + rules in the JSON schema will not be enforced by the driver and + will result in an error. + :param bypass_auto_encryption: If ``True``, automatic + encryption will be disabled but automatic decryption will still be + enabled. Defaults to ``False``. + :param mongocryptd_uri: The MongoDB URI used to connect + to the *local* mongocryptd process. Defaults to + ``'mongodb://localhost:27020'``. + :param mongocryptd_bypass_spawn: If ``True``, the encrypted + AsyncMongoClient/MongoClient will not attempt to spawn the mongocryptd process. + Defaults to ``False``. + :param mongocryptd_spawn_path: Used for spawning the + mongocryptd process. Defaults to ``'mongocryptd'`` and spawns + mongocryptd from the system path. + :param mongocryptd_spawn_args: A list of string arguments to + use when spawning the mongocryptd process. Defaults to + ``['--idleShutdownTimeoutSecs=60']``. 
If the list does not include + the ``idleShutdownTimeoutSecs`` option then + ``'--idleShutdownTimeoutSecs=60'`` will be added. + :param kms_tls_options: A map of KMS provider names to TLS + options to use when creating secure connections to KMS providers. + Accepts the same TLS options as + :class:`pymongo.mongo_client.AsyncMongoClient` and :class:`pymongo.mongo_client.MongoClient`. For example, to + override the system default CA file:: + + kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} + + Or to supply a client certificate:: + + kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} + :param crypt_shared_lib_path: Override the path to load the crypt_shared library. + :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is + unable to load the crypt_shared library. + :param bypass_query_analysis: If ``True``, disable automatic analysis + of outgoing commands. Set `bypass_query_analysis` to use explicit + encryption on indexed fields without the MongoDB Enterprise Advanced + licensed crypt_shared library. + :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents + that described the encrypted fields for Queryable Encryption. For example:: + + { + "db.encryptedCollection": { + "escCollection": "enxcol_.encryptedCollection.esc", + "ecocCollection": "enxcol_.encryptedCollection.ecoc", + "fields": [ + { + "path": "firstName", + "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), + "bsonType": "string", + "queries": {"queryType": "equality"} + }, + { + "path": "ssn", + "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), + "bsonType": "string" + } + ] + } + } + + .. versionchanged:: 4.2 + Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, + and `bypass_query_analysis` parameters. + + .. versionchanged:: 4.0 + Added the `kms_tls_options` parameter and the "kmip" KMS provider. + + .. 
versionadded:: 3.9 + """ + if not _HAVE_PYMONGOCRYPT: + raise ConfigurationError( + "client side encryption requires the pymongocrypt library: " + "install a compatible version with: " + "python -m pip install 'pymongo[encryption]'" + ) + if encrypted_fields_map: + validate_is_mapping("encrypted_fields_map", encrypted_fields_map) + self._encrypted_fields_map = encrypted_fields_map + self._bypass_query_analysis = bypass_query_analysis + self._crypt_shared_lib_path = crypt_shared_lib_path + self._crypt_shared_lib_required = crypt_shared_lib_required + self._kms_providers = kms_providers + self._key_vault_namespace = key_vault_namespace + self._key_vault_client = key_vault_client + self._schema_map = schema_map + self._bypass_auto_encryption = bypass_auto_encryption + self._mongocryptd_uri = mongocryptd_uri + self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn + self._mongocryptd_spawn_path = mongocryptd_spawn_path + if mongocryptd_spawn_args is None: + mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] + self._mongocryptd_spawn_args = mongocryptd_spawn_args + if not isinstance(self._mongocryptd_spawn_args, list): + raise TypeError("mongocryptd_spawn_args must be a list") + if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): + self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") + # Maps KMS provider name to a SSLContext. + self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) + self._bypass_query_analysis = bypass_query_analysis + + +class RangeOpts: + """Options to configure encrypted queries using the range algorithm.""" + + def __init__( + self, + sparsity: int, + trim_factor: int, + min: Optional[Any] = None, + max: Optional[Any] = None, + precision: Optional[int] = None, + ) -> None: + """Options to configure encrypted queries using the range algorithm. + + :param sparsity: An integer. + :param trim_factor: An integer. + :param min: A BSON scalar value corresponding to the type being queried. 
+ :param max: A BSON scalar value corresponding to the type being queried. + :param precision: An integer, may only be set for double or decimal128 types. + + .. versionadded:: 4.4 + """ + self.min = min + self.max = max + self.sparsity = sparsity + self.trim_factor = trim_factor + self.precision = precision + + @property + def document(self) -> dict[str, Any]: + doc = {} + for k, v in [ + ("sparsity", int64.Int64(self.sparsity)), + ("trimFactor", self.trim_factor), + ("precision", self.precision), + ("min", self.min), + ("max", self.max), + ]: + if v is not None: + doc[k] = v + return doc diff --git a/pymongo/errors.py b/pymongo/errors.py index 7efbc1ff3..a781e4a01 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Un from bson.errors import InvalidDocument if TYPE_CHECKING: - from pymongo.asynchronous.typings import _DocumentOut + from pymongo.typings import _DocumentOut class PyMongoError(Exception): diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 756e90ba2..287db3fc4 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,212 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous EventLoggers API for compatibility.""" + +"""Example event logger classes. + +.. versionadded:: 3.11 + +These loggers can be registered using :func:`register` or +:class:`~pymongo.mongo_client.MongoClient`. 
+ +``monitoring.register(CommandLogger())`` + +or + +``MongoClient(event_listeners=[CommandLogger()])`` +""" from __future__ import annotations -from pymongo.synchronous.event_loggers import * # noqa: F403 -from pymongo.synchronous.event_loggers import __doc__ as original_doc +import logging -__doc__ = original_doc +from pymongo import monitoring + + +class CommandLogger(monitoring.CommandListener): + """A simple listener that logs command events. + + Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, + :class:`~pymongo.monitoring.CommandSucceededEvent` and + :class:`~pymongo.monitoring.CommandFailedEvent` events and + logs them at the `INFO` severity level using :mod:`logging`. + .. versionadded:: 3.11 + """ + + def started(self, event: monitoring.CommandStartedEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} started on server " + f"{event.connection_id}" + ) + + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} on server {event.connection_id} " + f"succeeded in {event.duration_micros} " + "microseconds" + ) + + def failed(self, event: monitoring.CommandFailedEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} on server {event.connection_id} " + f"failed in {event.duration_micros} " + "microseconds" + ) + + +class ServerLogger(monitoring.ServerListener): + """A simple listener that logs server discovery events. + + Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, + :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, + and :class:`~pymongo.monitoring.ServerClosedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. 
versionadded:: 3.11 + """ + + def opened(self, event: monitoring.ServerOpeningEvent) -> None: + logging.info(f"Server {event.server_address} added to topology {event.topology_id}") + + def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: + previous_server_type = event.previous_description.server_type + new_server_type = event.new_description.server_type + if new_server_type != previous_server_type: + # server_type_name was added in PyMongo 3.4 + logging.info( + f"Server {event.server_address} changed type from " + f"{event.previous_description.server_type_name} to " + f"{event.new_description.server_type_name}" + ) + + def closed(self, event: monitoring.ServerClosedEvent) -> None: + logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") + + +class HeartbeatLogger(monitoring.ServerHeartbeatListener): + """A simple listener that logs server heartbeat events. + + Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, + :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, + and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: + logging.info(f"Heartbeat sent to server {event.connection_id}") + + def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: + # The reply.document attribute was added in PyMongo 3.4. + logging.info( + f"Heartbeat to server {event.connection_id} " + "succeeded with reply " + f"{event.reply.document}" + ) + + def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: + logging.warning( + f"Heartbeat to server {event.connection_id} failed with error {event.reply}" + ) + + +class TopologyLogger(monitoring.TopologyListener): + """A simple listener that logs server topology events. 
+ + Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, + :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, + and :class:`~pymongo.monitoring.TopologyClosedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def opened(self, event: monitoring.TopologyOpenedEvent) -> None: + logging.info(f"Topology with id {event.topology_id} opened") + + def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: + logging.info(f"Topology description updated for topology id {event.topology_id}") + previous_topology_type = event.previous_description.topology_type + new_topology_type = event.new_description.topology_type + if new_topology_type != previous_topology_type: + # topology_type_name was added in PyMongo 3.4 + logging.info( + f"Topology {event.topology_id} changed type from " + f"{event.previous_description.topology_type_name} to " + f"{event.new_description.topology_type_name}" + ) + # The has_writable_server and has_readable_server methods + # were added in PyMongo 3.4. + if not event.new_description.has_writable_server(): + logging.warning("No writable servers available.") + if not event.new_description.has_readable_server(): + logging.warning("No readable servers available.") + + def closed(self, event: monitoring.TopologyClosedEvent) -> None: + logging.info(f"Topology with id {event.topology_id} closed") + + +class ConnectionPoolLogger(monitoring.ConnectionPoolListener): + """A simple listener that logs server connection pool events. 
+ + Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, + :class:`~pymongo.monitoring.PoolClearedEvent`, + :class:`~pymongo.monitoring.PoolClosedEvent`, + :~pymongo.monitoring.class:`ConnectionCreatedEvent`, + :class:`~pymongo.monitoring.ConnectionReadyEvent`, + :class:`~pymongo.monitoring.ConnectionClosedEvent`, + :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, + :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, + :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, + and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: + logging.info(f"[pool {event.address}] pool created") + + def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: + logging.info(f"[pool {event.address}] pool ready") + + def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: + logging.info(f"[pool {event.address}] pool cleared") + + def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: + logging.info(f"[pool {event.address}] pool closed") + + def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: + logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") + + def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" + ) + + def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] " + f'connection closed, reason: "{event.reason}"' + ) + + def connection_check_out_started( + self, event: monitoring.ConnectionCheckOutStartedEvent + ) -> None: + logging.info(f"[pool {event.address}] connection check out started") + + def connection_check_out_failed(self, event: 
monitoring.ConnectionCheckOutFailedEvent) -> None: + logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") + + def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" + ) + + def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" + ) diff --git a/pymongo/asynchronous/hello.py b/pymongo/hello.py similarity index 96% rename from pymongo/asynchronous/hello.py rename to pymongo/hello.py index 3826e8a27..62bb79980 100644 --- a/pymongo/asynchronous/hello.py +++ b/pymongo/hello.py @@ -21,12 +21,9 @@ import itertools from typing import Any, Generic, Mapping, Optional from bson.objectid import ObjectId -from pymongo.asynchronous import common -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.typings import ClusterTime, _DocumentType +from pymongo import common from pymongo.server_type import SERVER_TYPE - -_IS_SYNC = False +from pymongo.typings import ClusterTime, _DocumentType def _get_server_type(doc: Mapping[str, Any]) -> int: @@ -57,6 +54,14 @@ def _get_server_type(doc: Mapping[str, Any]) -> int: return SERVER_TYPE.Standalone +class HelloCompat: + CMD = "hello" + LEGACY_CMD = "ismaster" + PRIMARY = "isWritablePrimary" + LEGACY_PRIMARY = "ismaster" + LEGACY_ERROR = "not master" + + class Hello(Generic[_DocumentType]): """Parse a hello response from the server. diff --git a/pymongo/helpers_constants.py b/pymongo/helpers_constants.py deleted file mode 100644 index 00b250270..000000000 --- a/pymongo/helpers_constants.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Constants used by the driver that don't really fit elsewhere.""" - -# From the SDAM spec, the "node is shutting down" codes. -from __future__ import annotations - -_SHUTDOWN_CODES: frozenset = frozenset( - [ - 11600, # InterruptedAtShutdown - 91, # ShutdownInProgress - ] -) -# From the SDAM spec, the "not primary" error codes are combined with the -# "node is recovering" error codes (of which the "node is shutting down" -# errors are a subset). -_NOT_PRIMARY_CODES: frozenset = ( - frozenset( - [ - 10058, # LegacyNotPrimary <=3.2 "not primary" error code - 10107, # NotWritablePrimary - 13435, # NotPrimaryNoSecondaryOk - 11602, # InterruptedDueToReplStateChange - 13436, # NotPrimaryOrSecondary - 189, # PrimarySteppedDown - ] - ) - | _SHUTDOWN_CODES -) -# From the retryable writes spec. -_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset( - [ - 7, # HostNotFound - 6, # HostUnreachable - 89, # NetworkTimeout - 9001, # SocketException - 262, # ExceededTimeLimit - 134, # ReadConcernMajorityNotAvailableYet - ] -) - -# Server code raised when re-authentication is required -_REAUTHENTICATION_REQUIRED_CODE: int = 391 - -# Server code raised when authentication fails. -_AUTHENTICATION_FAILURE_CODE: int = 18 - -# Note - to avoid bugs from forgetting which if these is all lowercase and -# which are camelCase, and at the same time avoid having to add a test for -# every command, use all lowercase here and test against command_name.lower(). 
-_SENSITIVE_COMMANDS: set = { - "authenticate", - "saslstart", - "saslcontinue", - "getnonce", - "createuser", - "updateuser", - "copydbgetnonce", - "copydbsaslstart", - "copydb", -} diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py new file mode 100644 index 000000000..83ea2ddf7 --- /dev/null +++ b/pymongo/helpers_shared.py @@ -0,0 +1,327 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bits and pieces used by the driver that don't really fit elsewhere.""" +from __future__ import annotations + +import sys +import traceback +from collections import abc +from typing import ( + TYPE_CHECKING, + Any, + Container, + Iterable, + Mapping, + NoReturn, + Optional, + Sequence, + Union, +) + +from pymongo import ASCENDING +from pymongo.errors import ( + CursorNotFound, + DuplicateKeyError, + ExecutionTimeout, + NotPrimaryError, + OperationFailure, + WriteConcernError, + WriteError, + WTimeoutError, + _wtimeout_error, +) +from pymongo.hello import HelloCompat + +if TYPE_CHECKING: + from pymongo.cursor_shared import _Hint + from pymongo.operations import _IndexList + from pymongo.typings import _DocumentOut + + +# From the SDAM spec, the "node is shutting down" codes. 
+ +_SHUTDOWN_CODES: frozenset = frozenset( + [ + 11600, # InterruptedAtShutdown + 91, # ShutdownInProgress + ] +) +# From the SDAM spec, the "not primary" error codes are combined with the +# "node is recovering" error codes (of which the "node is shutting down" +# errors are a subset). +_NOT_PRIMARY_CODES: frozenset = ( + frozenset( + [ + 10058, # LegacyNotPrimary <=3.2 "not primary" error code + 10107, # NotWritablePrimary + 13435, # NotPrimaryNoSecondaryOk + 11602, # InterruptedDueToReplStateChange + 13436, # NotPrimaryOrSecondary + 189, # PrimarySteppedDown + ] + ) + | _SHUTDOWN_CODES +) +# From the retryable writes spec. +_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset( + [ + 7, # HostNotFound + 6, # HostUnreachable + 89, # NetworkTimeout + 9001, # SocketException + 262, # ExceededTimeLimit + 134, # ReadConcernMajorityNotAvailableYet + ] +) + +# Server code raised when re-authentication is required +_REAUTHENTICATION_REQUIRED_CODE: int = 391 + +# Server code raised when authentication fails. +_AUTHENTICATION_FAILURE_CODE: int = 18 + +# Note - to avoid bugs from forgetting which if these is all lowercase and +# which are camelCase, and at the same time avoid having to add a test for +# every command, use all lowercase here and test against command_name.lower(). +_SENSITIVE_COMMANDS: set = { + "authenticate", + "saslstart", + "saslcontinue", + "getnonce", + "createuser", + "updateuser", + "copydbgetnonce", + "copydbsaslstart", + "copydb", +} + + +def _gen_index_name(keys: _IndexList) -> str: + """Generate an index name from the set of fields it is over.""" + return "_".join(["{}_{}".format(*item) for item in keys]) + + +def _index_list( + key_or_list: _Hint, direction: Optional[Union[int, str]] = None +) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: + """Helper to generate a list of (key, direction) pairs. + + Takes such a list, or a single key, or a single key and direction. 
+ """ + if direction is not None: + if not isinstance(key_or_list, str): + raise TypeError("Expected a string and a direction") + return [(key_or_list, direction)] + else: + if isinstance(key_or_list, str): + return [(key_or_list, ASCENDING)] + elif isinstance(key_or_list, abc.ItemsView): + return list(key_or_list) # type: ignore[arg-type] + elif isinstance(key_or_list, abc.Mapping): + return list(key_or_list.items()) + elif not isinstance(key_or_list, (list, tuple)): + raise TypeError("if no direction is specified, key_or_list must be an instance of list") + values: list[tuple[str, int]] = [] + for item in key_or_list: + if isinstance(item, str): + item = (item, ASCENDING) # noqa: PLW2901 + values.append(item) + return values + + +def _index_document(index_list: _IndexList) -> dict[str, Any]: + """Helper to generate an index specifying document. + + Takes a list of (key, direction) pairs. + """ + if not isinstance(index_list, (list, tuple, abc.Mapping)): + raise TypeError( + "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) + ) + if not len(index_list): + raise ValueError("key_or_list must not be empty") + + index: dict[str, Any] = {} + + if isinstance(index_list, abc.Mapping): + for key in index_list: + value = index_list[key] + _validate_index_key_pair(key, value) + index[key] = value + else: + for item in index_list: + if isinstance(item, str): + item = (item, ASCENDING) # noqa: PLW2901 + key, value = item + _validate_index_key_pair(key, value) + index[key] = value + return index + + +def _validate_index_key_pair(key: Any, value: Any) -> None: + if not isinstance(key, str): + raise TypeError("first item in each key pair must be an instance of str") + if not isinstance(value, (str, int, abc.Mapping)): + raise TypeError( + "second item in each key pair must be 1, -1, " + "'2d', or another valid MongoDB index specifier." 
+ ) + + +def _check_command_response( + response: _DocumentOut, + max_wire_version: Optional[int], + allowable_errors: Optional[Container[Union[int, str]]] = None, + parse_write_concern_error: bool = False, +) -> None: + """Check the response to a command for errors.""" + if "ok" not in response: + # Server didn't recognize our message as a command. + raise OperationFailure( + response.get("$err"), # type: ignore[arg-type] + response.get("code"), + response, + max_wire_version, + ) + + if parse_write_concern_error and "writeConcernError" in response: + _error = response["writeConcernError"] + _labels = response.get("errorLabels") + if _labels: + _error.update({"errorLabels": _labels}) + _raise_write_concern_error(_error) + + if response["ok"]: + return + + details = response + # Mongos returns the error details in a 'raw' object + # for some errors. + if "raw" in response: + for shard in response["raw"].values(): + # Grab the first non-empty raw error from a shard. + if shard.get("errmsg") and not shard.get("ok"): + details = shard + break + + errmsg = details["errmsg"] + code = details.get("code") + + # For allowable errors, only check for error messages when the code is not + # included. 
+ if allowable_errors: + if code is not None: + if code in allowable_errors: + return + elif errmsg in allowable_errors: + return + + # Server is "not primary" or "recovering" + if code is not None: + if code in _NOT_PRIMARY_CODES: + raise NotPrimaryError(errmsg, response) + elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: + raise NotPrimaryError(errmsg, response) + + # Other errors + # findAndModify with upsert can raise duplicate key error + if code in (11000, 11001, 12582): + raise DuplicateKeyError(errmsg, code, response, max_wire_version) + elif code == 50: + raise ExecutionTimeout(errmsg, code, response, max_wire_version) + elif code == 43: + raise CursorNotFound(errmsg, code, response, max_wire_version) + + raise OperationFailure(errmsg, code, response, max_wire_version) + + +def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: + # If the last batch had multiple errors only report + # the last error to emulate continue_on_error. + error = write_errors[-1] + if error.get("code") == 11000: + raise DuplicateKeyError(error.get("errmsg"), 11000, error) + raise WriteError(error.get("errmsg"), error.get("code"), error) + + +def _raise_write_concern_error(error: Any) -> NoReturn: + if _wtimeout_error(error): + # Make sure we raise WTimeoutError + raise WTimeoutError(error.get("errmsg"), error.get("code"), error) + raise WriteConcernError(error.get("errmsg"), error.get("code"), error) + + +def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: + """Return the writeConcernError or None.""" + wce = result.get("writeConcernError") + if wce: + # The server reports errorLabels at the top level but it's more + # convenient to attach it to the writeConcernError doc itself. + error_labels = result.get("errorLabels") + if error_labels: + # Copy to avoid changing the original document. 
+ wce = wce.copy() + wce["errorLabels"] = error_labels + return wce + + +def _check_write_command_response(result: Mapping[str, Any]) -> None: + """Backward compatibility helper for write command error handling.""" + # Prefer write errors over write concern errors + write_errors = result.get("writeErrors") + if write_errors: + _raise_last_write_error(write_errors) + + wce = _get_wce_doc(result) + if wce: + _raise_write_concern_error(wce) + + +def _fields_list_to_dict( + fields: Union[Mapping[str, Any], Iterable[str]], option_name: str +) -> Mapping[str, Any]: + """Takes a sequence of field names and returns a matching dictionary. + + ["a", "b"] becomes {"a": 1, "b": 1} + + and + + ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} + """ + if isinstance(fields, abc.Mapping): + return fields + + if isinstance(fields, (abc.Sequence, abc.Set)): + if not all(isinstance(field, str) for field in fields): + raise TypeError(f"{option_name} must be a list of key names, each an instance of str") + return dict.fromkeys(fields, 1) + + raise TypeError(f"{option_name} must be a mapping or list of key names") + + +def _handle_exception() -> None: + """Print exceptions raised by subscribers to stderr.""" + # Heavily influenced by logging.Handler.handleError. 
+ + # See note here: + # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ + if sys.stderr: + einfo = sys.exc_info() + try: + traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) + except OSError: + pass + finally: + del einfo diff --git a/pymongo/synchronous/logger.py b/pymongo/logger.py similarity index 98% rename from pymongo/synchronous/logger.py rename to pymongo/logger.py index d0f539ee6..2caafa778 100644 --- a/pymongo/synchronous/logger.py +++ b/pymongo/logger.py @@ -21,9 +21,7 @@ from typing import Any from bson import UuidRepresentation, json_util from bson.json_util import JSONOptions, _truncate_documents -from pymongo.synchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason - -_IS_SYNC = True +from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason class _CommandStatusMessage(str, enum.Enum): diff --git a/pymongo/synchronous/max_staleness_selectors.py b/pymongo/max_staleness_selectors.py similarity index 98% rename from pymongo/synchronous/max_staleness_selectors.py rename to pymongo/max_staleness_selectors.py index cde43890d..89bfa6528 100644 --- a/pymongo/synchronous/max_staleness_selectors.py +++ b/pymongo/max_staleness_selectors.py @@ -34,9 +34,8 @@ from pymongo.errors import ConfigurationError from pymongo.server_type import SERVER_TYPE if TYPE_CHECKING: - from pymongo.synchronous.server_selectors import Selection + from pymongo.server_selectors import Selection -_IS_SYNC = True # Constant defined in Max Staleness Spec: An idle primary writes a no-op every # 10 seconds to refresh secondaries' lastWriteDate values. diff --git a/pymongo/synchronous/message.py b/pymongo/message.py similarity index 78% rename from pymongo/synchronous/message.py rename to pymongo/message.py index 0eca1e8f1..bcb4ce10e 100644 --- a/pymongo/synchronous/message.py +++ b/pymongo/message.py @@ -22,7 +22,6 @@ MongoDB. 
from __future__ import annotations import datetime -import logging import random import struct from io import BytesIO as _BytesIO @@ -39,7 +38,7 @@ from typing import ( ) import bson -from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode +from bson import CodecOptions, _dict_to_bson, _make_c_string from bson.int64 import Int64 from bson.raw_bson import ( _RAW_ARRAY_BSON_OPTIONS, @@ -47,6 +46,8 @@ from bson.raw_bson import ( RawBSONDocument, _inflate_bson, ) +from pymongo.hello import HelloCompat +from pymongo.monitoring import _EventListeners try: from pymongo import _cmessage # type: ignore[attr-defined] @@ -64,31 +65,20 @@ from pymongo.errors import ( OperationFailure, ProtocolError, ) -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.logger import ( - _COMMAND_LOGGER, - _CommandStatusMessage, - _debug_log, -) -from pymongo.synchronous.read_preferences import ReadPreference -from pymongo.write_concern import WriteConcern +from pymongo.read_preferences import ReadPreference, _ServerMode if TYPE_CHECKING: - from datetime import timedelta - + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.read_concern import ReadConcern - from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.compression_support import SnappyContext, ZlibContext, ZstdContext - from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.monitoring import _EventListeners - from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.typings import _Address, _DocumentOut + from pymongo.typings import ( + _Address, + _AgnosticClientSession, + _AgnosticConnection, + _AgnosticMongoClient, + _DocumentOut, + ) -_IS_SYNC = True - MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 @@ -217,7 +207,7 @@ def 
_gen_find_command( options: Optional[int], read_concern: ReadConcern, collation: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[_AgnosticClientSession] = None, allow_disk_use: Optional[bool] = None, ) -> dict[str, Any]: """Generate a find command document.""" @@ -264,7 +254,7 @@ def _gen_get_more_command( batch_size: Optional[int], max_await_time_ms: Optional[int], comment: Optional[Any], - conn: Connection, + conn: _AgnosticConnection, ) -> dict[str, Any]: """Generate a getMore command document.""" cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll} @@ -277,362 +267,6 @@ def _gen_get_more_command( return cmd -class _Query: - """A query operation.""" - - __slots__ = ( - "flags", - "db", - "coll", - "ntoskip", - "spec", - "fields", - "codec_options", - "read_preference", - "limit", - "batch_size", - "name", - "read_concern", - "collation", - "session", - "client", - "allow_disk_use", - "_as_command", - "exhaust", - ) - - # For compatibility with the _GetMore class. 
- conn_mgr = None - cursor_id = None - - def __init__( - self, - flags: int, - db: str, - coll: str, - ntoskip: int, - spec: Mapping[str, Any], - fields: Optional[Mapping[str, Any]], - codec_options: CodecOptions, - read_preference: _ServerMode, - limit: int, - batch_size: int, - read_concern: ReadConcern, - collation: Optional[Mapping[str, Any]], - session: Optional[ClientSession], - client: MongoClient, - allow_disk_use: Optional[bool], - exhaust: bool, - ): - self.flags = flags - self.db = db - self.coll = coll - self.ntoskip = ntoskip - self.spec = spec - self.fields = fields - self.codec_options = codec_options - self.read_preference = read_preference - self.read_concern = read_concern - self.limit = limit - self.batch_size = batch_size - self.collation = collation - self.session = session - self.client = client - self.allow_disk_use = allow_disk_use - self.name = "find" - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: Connection) -> bool: - use_find_cmd = False - if not self.exhaust: - use_find_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_find_cmd = True - elif not self.read_concern.ok_for_legacy: - raise ConfigurationError( - "read concern level of %s is not valid " - "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) - ) - - conn.validate_session(self.client, self.session) - return use_find_cmd - - def as_command( - self, conn: Connection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a find command document for this query.""" - # We use the command twice: on the wire and for command monitoring. - # Generate it once, for speed and to avoid repeating side-effects. 
- if self._as_command is not None: - return self._as_command - - explain = "$explain" in self.spec - cmd: dict[str, Any] = _gen_find_command( - self.coll, - self.spec, - self.fields, - self.ntoskip, - self.limit, - self.batch_size, - self.flags, - self.read_concern, - self.collation, - self.session, - self.allow_disk_use, - ) - if explain: - self.name = "explain" - cmd = {"explain": cmd} - session = self.session - conn.add_server_api(cmd) - if session: - session._apply_to(cmd, False, self.read_preference, conn) - # Explain does not support readConcern. - if not explain and not session.in_transaction: - session._update_read_concern(cmd, conn) - conn.send_cluster_time(cmd, session, self.client) - # Support auto encryption - client = self.client - if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) - # Support CSOT - if apply_timeout: - conn.apply_timeout(client, cmd) - self._as_command = cmd, self.db - return self._as_command - - def get_message( - self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False - ) -> tuple[int, bytes, int]: - """Get a query message, possibly setting the secondaryOk bit.""" - # Use the read_preference decided by _socket_from_server. - self.read_preference = read_preference - if read_preference.mode: - # Set the secondaryOk bit. - flags = self.flags | 4 - else: - flags = self.flags - - ns = self.namespace() - spec = self.spec - - if use_cmd: - spec = (self.as_command(conn, apply_timeout=True))[0] - request_id, msg, size, _ = _op_msg( - 0, - spec, - self.db, - read_preference, - self.codec_options, - ctx=conn.compression_context, - ) - return request_id, msg, size - - # OP_QUERY treats ntoreturn of -1 and 1 the same, return - # one document and close the cursor. We have to use 2 for - # batch size if 1 is specified. 
- ntoreturn = self.batch_size == 1 and 2 or self.batch_size - if self.limit: - if ntoreturn: - ntoreturn = min(self.limit, ntoreturn) - else: - ntoreturn = self.limit - - if conn.is_mongos: - assert isinstance(spec, MutableMapping) - spec = _maybe_add_read_preference(spec, read_preference) - - return _query( - flags, - ns, - self.ntoskip, - ntoreturn, - spec, - None if use_cmd else self.fields, - self.codec_options, - ctx=conn.compression_context, - ) - - -class _GetMore: - """A getmore operation.""" - - __slots__ = ( - "db", - "coll", - "ntoreturn", - "cursor_id", - "max_await_time_ms", - "codec_options", - "read_preference", - "session", - "client", - "conn_mgr", - "_as_command", - "exhaust", - "comment", - ) - - name = "getMore" - - def __init__( - self, - db: str, - coll: str, - ntoreturn: int, - cursor_id: int, - codec_options: CodecOptions, - read_preference: _ServerMode, - session: Optional[ClientSession], - client: MongoClient, - max_await_time_ms: Optional[int], - conn_mgr: Any, - exhaust: bool, - comment: Any, - ): - self.db = db - self.coll = coll - self.ntoreturn = ntoreturn - self.cursor_id = cursor_id - self.codec_options = codec_options - self.read_preference = read_preference - self.session = session - self.client = client - self.max_await_time_ms = max_await_time_ms - self.conn_mgr = conn_mgr - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - self.comment = comment - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: Connection) -> bool: - use_cmd = False - if not self.exhaust: - use_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_cmd = True - - conn.validate_session(self.client, self.session) - return use_cmd - - def as_command( - self, conn: Connection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a getMore command document for this 
query.""" - # See _Query.as_command for an explanation of this caching. - if self._as_command is not None: - return self._as_command - - cmd: dict[str, Any] = _gen_get_more_command( - self.cursor_id, - self.coll, - self.ntoreturn, - self.max_await_time_ms, - self.comment, - conn, - ) - if self.session: - self.session._apply_to(cmd, False, self.read_preference, conn) - conn.add_server_api(cmd) - conn.send_cluster_time(cmd, self.session, self.client) - # Support auto encryption - client = self.client - if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) - # Support CSOT - if apply_timeout: - conn.apply_timeout(client, cmd=None) - self._as_command = cmd, self.db - return self._as_command - - def get_message( - self, dummy0: Any, conn: Connection, use_cmd: bool = False - ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: - """Get a getmore message.""" - ns = self.namespace() - ctx = conn.compression_context - - if use_cmd: - spec = (self.as_command(conn, apply_timeout=True))[0] - if self.conn_mgr and self.exhaust: - flags = _OpMsg.EXHAUST_ALLOWED - else: - flags = 0 - request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context - ) - return request_id, msg, size - - return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) - - -class _RawBatchQuery(_Query): - def use_command(self, conn: Connection) -> bool: - # Compatibility checks. - super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False - - -class _RawBatchGetMore(_GetMore): - def use_command(self, conn: Connection) -> bool: - # Compatibility checks. 
- super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False - - -class _CursorAddress(tuple): - """The server address (host, port) of a cursor, with namespace property.""" - - __namespace: Any - - def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: - self = tuple.__new__(cls, address) - self.__namespace = namespace - return self - - @property - def namespace(self) -> str: - """The namespace this cursor.""" - return self.__namespace - - def __hash__(self) -> int: - # Two _CursorAddress instances with different namespaces - # must not hash the same. - return ((*self, self.__namespace)).__hash__() - - def __eq__(self, other: object) -> bool: - if isinstance(other, _CursorAddress): - return tuple(self) == tuple(other) and self.namespace == other.namespace - return NotImplemented - - def __ne__(self, other: object) -> bool: - return not self == other - - _pack_compression_header = struct.Struct(" return b"".join( [ _ZERO_32, - _make_c_string(collection_name), + bson._make_c_string(collection_name), _pack_int(num_to_return), _pack_long_long(cursor_id), ] @@ -907,8 +541,18 @@ def _get_more( return _get_more_uncompressed(collection_name, num_to_return, cursor_id) +# OP_MSG ------------------------------------------------------------- + + +_OP_MSG_MAP = { + _INSERT: b"documents\x00", + _UPDATE: b"updates\x00", + _DELETE: b"deletes\x00", +} + + class _BulkWriteContext: - """A wrapper around Connection for use with write splitting functions.""" + """A wrapper around AsyncConnection for use with write splitting functions.""" __slots__ = ( "db_name", @@ -929,10 +573,10 @@ class _BulkWriteContext: self, database_name: str, cmd_name: str, - conn: Connection, + conn: _AgnosticConnection, operation_id: int, listeners: _EventListeners, - session: ClientSession, + session: _AgnosticClientSession, op_type: int, codec: CodecOptions, ): @@ -949,9 +593,9 
@@ class _BulkWriteContext: self.op_type = op_type self.codec = codec - def __batch_command( + def batch_command( self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[int, bytes, list[Mapping[str, Any]]]: + ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" request_id, msg, to_send = _do_batched_op_msg( namespace, self.op_type, cmd, docs, self.codec, self @@ -960,26 +604,6 @@ class _BulkWriteContext: raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send - def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - result = self.write_command(cmd, request_id, msg, to_send, client) - client._process_response(result, self.session) - return result, to_send - - def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> list[Mapping[str, Any]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - # Though this isn't strictly a "legacy" write, the helper - # handles publishing commands and sending our message - # without receiving a result. Send 0 for max_doc_size - # to disable size checking. Size checking is handled while - # the documents are encoded to BSON. 
- self.unack_write(cmd, request_id, msg, 0, to_send, client) - return to_send - @property def max_bson_size(self) -> int: """A proxy for SockInfo.max_bson_size.""" @@ -1003,178 +627,6 @@ class _BulkWriteContext: """The maximum size of a BSON command before batch splitting.""" return self.max_bson_size - def unack_write( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - max_doc_size: int, - docs: list[Mapping[str, Any]], - client: MongoClient, - ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - cmd = self._start(cmd, request_id, docs) - try: - result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value] - duration = datetime.datetime.now() - self.start_time - if result is not None: - reply = _convert_write_result(self.name, cmd, result) - else: - # Comply with APM spec. 
- reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if self.publish: - assert self.start_time is not None - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return result - - @_handle_reauth - def write_command( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - docs: list[Mapping[str, Any]], - client: MongoClient, - ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" - 
cmd[self.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._start(cmd, request_id, docs) - try: - reply = self.conn.write_command(request_id, msg, self.codec) - duration = datetime.datetime.now() - self.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - 
serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if self.publish: - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return reply - def _start( self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] ) -> MutableMapping[str, Any]: @@ -1191,7 +643,7 @@ class _BulkWriteContext: ) return cmd - def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None: + def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: """Publish a CommandSucceededEvent.""" self.listeners.publish_command_success( duration, @@ -1205,7 +657,7 @@ class _BulkWriteContext: database_name=self.db_name, ) - def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None: + def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: """Publish a CommandFailedEvent.""" self.listeners.publish_command_failure( duration, @@ -1220,19 +672,12 @@ class _BulkWriteContext: ) -# From the Client Side Encryption spec: -# Because automatic encryption increases the size of commands, the driver -# MUST split bulk writes at a reduced size limit before undergoing automatic -# encryption. The write payload MUST be split at 2MiB (2097152). -_MAX_SPLIT_SIZE_ENC = 2097152 - - class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () - def __batch_command( + def batch_command( self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: + ) -> tuple[int, dict[str, Any], list[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" msg, to_send = _encode_batched_write_command( namespace, self.op_type, cmd, docs, self.codec, self @@ -1243,29 +688,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): # Chop off the OP_QUERY header to get a properly batched write command. 
cmd_start = msg.index(b"\x00", 4) + 9 outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) - return outgoing, to_send - - def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - result: Mapping[str, Any] = self.conn.command( - self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client - ) - return result, to_send - - def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> list[Mapping[str, Any]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - self.conn.command( - self.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=self.session, - client=client, - ) - return to_send + return -1, outgoing, to_send @property def max_split_size(self) -> int: @@ -1288,14 +711,11 @@ def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> N raise DocumentTooLarge(f"{operation!r} command document too large") -# OP_MSG ------------------------------------------------------------- - - -_OP_MSG_MAP = { - _INSERT: b"documents\x00", - _UPDATE: b"updates\x00", - _DELETE: b"deletes\x00", -} +# From the Client Side Encryption spec: +# Because automatic encryption increases the size of commands, the driver +# MUST split bulk writes at a reduced size limit before undergoing automatic +# encryption. The write payload MUST be split at 2MiB (2097152). 
+_MAX_SPLIT_SIZE_ENC = 2097152 def _batched_op_msg_impl( @@ -1507,7 +927,7 @@ def _batched_write_command_impl( # Where to write command document length command_start = buf.tell() - buf.write(encode(command)) + buf.write(bson.encode(command)) # Start of payload buf.seek(-1, 2) @@ -1696,7 +1116,7 @@ class _OpMsg: cursor_id is ignored user_fields is used to determine which fields must not be decoded """ - inflated_response = _decode_selective( + inflated_response = bson._decode_selective( RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS ) return [inflated_response] @@ -1758,3 +1178,356 @@ _UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { _OpReply.OP_CODE: _OpReply.unpack, _OpMsg.OP_CODE: _OpMsg.unpack, } + + +class _Query: + """A query operation.""" + + __slots__ = ( + "flags", + "db", + "coll", + "ntoskip", + "spec", + "fields", + "codec_options", + "read_preference", + "limit", + "batch_size", + "name", + "read_concern", + "collation", + "session", + "client", + "allow_disk_use", + "_as_command", + "exhaust", + ) + + # For compatibility with the _GetMore class. 
+ conn_mgr = None + cursor_id = None + + def __init__( + self, + flags: int, + db: str, + coll: str, + ntoskip: int, + spec: Mapping[str, Any], + fields: Optional[Mapping[str, Any]], + codec_options: CodecOptions, + read_preference: _ServerMode, + limit: int, + batch_size: int, + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]], + session: Optional[_AgnosticClientSession], + client: _AgnosticMongoClient, + allow_disk_use: Optional[bool], + exhaust: bool, + ): + self.flags = flags + self.db = db + self.coll = coll + self.ntoskip = ntoskip + self.spec = spec + self.fields = fields + self.codec_options = codec_options + self.read_preference = read_preference + self.read_concern = read_concern + self.limit = limit + self.batch_size = batch_size + self.collation = collation + self.session = session + self.client = client + self.allow_disk_use = allow_disk_use + self.name = "find" + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + + def reset(self) -> None: + self._as_command = None + + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + + def use_command(self, conn: _AgnosticConnection) -> bool: + use_find_cmd = False + if not self.exhaust: + use_find_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_find_cmd = True + elif not self.read_concern.ok_for_legacy: + raise ConfigurationError( + "read concern level of %s is not valid " + "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) + ) + + conn.validate_session(self.client, self.session) # type: ignore[arg-type] + return use_find_cmd + + def update_command(self, cmd: dict[str, Any]) -> None: + self._as_command = cmd, self.db + + def as_command( + self, conn: _AgnosticConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + """Return a find command document for this query.""" + # We use the command twice: on the wire and for command monitoring. 
+ # Generate it once, for speed and to avoid repeating side-effects. + if self._as_command is not None: + return self._as_command + + explain = "$explain" in self.spec + cmd: dict[str, Any] = _gen_find_command( + self.coll, + self.spec, + self.fields, + self.ntoskip, + self.limit, + self.batch_size, + self.flags, + self.read_concern, + self.collation, + self.session, + self.allow_disk_use, + ) + if explain: + self.name = "explain" + cmd = {"explain": cmd} + conn.add_server_api(cmd) + if self.session: + self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] + # Explain does not support readConcern. + if not explain and not self.session.in_transaction: + self.session._update_read_concern(cmd, conn) # type: ignore[arg-type] + conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] + # Support CSOT + if apply_timeout: + conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type] + self._as_command = cmd, self.db + return self._as_command + + def get_message( + self, read_preference: _ServerMode, conn: _AgnosticConnection, use_cmd: bool = False + ) -> tuple[int, bytes, int]: + """Get a query message, possibly setting the secondaryOk bit.""" + # Use the read_preference decided by _socket_from_server. + self.read_preference = read_preference + if read_preference.mode: + # Set the secondaryOk bit. + flags = self.flags | 4 + else: + flags = self.flags + + ns = self.namespace() + spec = self.spec + + if use_cmd: + spec = self.as_command(conn)[0] + request_id, msg, size, _ = _op_msg( + 0, + spec, + self.db, + read_preference, + self.codec_options, + ctx=conn.compression_context, + ) + return request_id, msg, size + + # OP_QUERY treats ntoreturn of -1 and 1 the same, return + # one document and close the cursor. We have to use 2 for + # batch size if 1 is specified. 
+ ntoreturn = self.batch_size == 1 and 2 or self.batch_size + if self.limit: + if ntoreturn: + ntoreturn = min(self.limit, ntoreturn) + else: + ntoreturn = self.limit + + if conn.is_mongos: + assert isinstance(spec, MutableMapping) + spec = _maybe_add_read_preference(spec, read_preference) + + return _query( + flags, + ns, + self.ntoskip, + ntoreturn, + spec, + None if use_cmd else self.fields, + self.codec_options, + ctx=conn.compression_context, + ) + + +class _GetMore: + """A getmore operation.""" + + __slots__ = ( + "db", + "coll", + "ntoreturn", + "cursor_id", + "max_await_time_ms", + "codec_options", + "read_preference", + "session", + "client", + "conn_mgr", + "_as_command", + "exhaust", + "comment", + ) + + name = "getMore" + + def __init__( + self, + db: str, + coll: str, + ntoreturn: int, + cursor_id: int, + codec_options: CodecOptions, + read_preference: _ServerMode, + session: Optional[_AgnosticClientSession], + client: _AgnosticMongoClient, + max_await_time_ms: Optional[int], + conn_mgr: Any, + exhaust: bool, + comment: Any, + ): + self.db = db + self.coll = coll + self.ntoreturn = ntoreturn + self.cursor_id = cursor_id + self.codec_options = codec_options + self.read_preference = read_preference + self.session = session + self.client = client + self.max_await_time_ms = max_await_time_ms + self.conn_mgr = conn_mgr + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + self.comment = comment + + def reset(self) -> None: + self._as_command = None + + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + + def use_command(self, conn: _AgnosticConnection) -> bool: + use_cmd = False + if not self.exhaust: + use_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_cmd = True + + conn.validate_session(self.client, self.session) # type: ignore[arg-type] + return use_cmd + + def update_command(self, cmd: dict[str, Any]) -> None: + self._as_command = cmd, self.db + + def 
as_command( + self, conn: _AgnosticConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + """Return a getMore command document for this query.""" + # See _Query.as_command for an explanation of this caching. + if self._as_command is not None: + return self._as_command + + cmd: dict[str, Any] = _gen_get_more_command( + self.cursor_id, + self.coll, + self.ntoreturn, + self.max_await_time_ms, + self.comment, + conn, + ) + if self.session: + self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] + conn.add_server_api(cmd) + conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] + # Support CSOT + if apply_timeout: + conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type] + self._as_command = cmd, self.db + return self._as_command + + def get_message( + self, dummy0: Any, conn: _AgnosticConnection, use_cmd: bool = False + ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: + """Get a getmore message.""" + ns = self.namespace() + ctx = conn.compression_context + + if use_cmd: + spec = self.as_command(conn)[0] + if self.conn_mgr and self.exhaust: + flags = _OpMsg.EXHAUST_ALLOWED + else: + flags = 0 + request_id, msg, size, _ = _op_msg( + flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context + ) + return request_id, msg, size + + return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) + + +class _RawBatchQuery(_Query): + def use_command(self, conn: _AgnosticConnection) -> bool: + # Compatibility checks. + super().use_command(conn) + if conn.max_wire_version >= 8: + # MongoDB 4.2+ supports exhaust over OP_MSG + return True + elif not self.exhaust: + return True + return False + + +class _RawBatchGetMore(_GetMore): + def use_command(self, conn: _AgnosticConnection) -> bool: + # Compatibility checks. 
+ super().use_command(conn) + if conn.max_wire_version >= 8: + # MongoDB 4.2+ supports exhaust over OP_MSG + return True + elif not self.exhaust: + return True + return False + + +class _CursorAddress(tuple): + """The server address (host, port) of a cursor, with namespace property.""" + + __namespace: Any + + def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: + self = tuple.__new__(cls, address) + self.__namespace = namespace + return self + + @property + def namespace(self) -> str: + """The namespace this cursor.""" + return self.__namespace + + def __hash__(self) -> int: + # Two _CursorAddress instances with different namespaces + # must not hash the same. + return ((*self, self.__namespace)).__hash__() + + def __eq__(self, other: object) -> bool: + if isinstance(other, _CursorAddress): + return tuple(self) == tuple(other) and self.namespace == other.namespace + return NotImplemented + + def __ne__(self, other: object) -> bool: + return not self == other diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 68c2bbc4b..a815cbc8a 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -19,3 +19,4 @@ from pymongo.synchronous.mongo_client import * # noqa: F403 from pymongo.synchronous.mongo_client import __doc__ as original_doc __doc__ = original_doc +__all__ = ["MongoClient"] # noqa: F405 diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index b9825b4ca..260213e18 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -1,21 +1,1900 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2015-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools to monitor driver events. + +.. versionadded:: 3.1 + +.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below + are included in the PyMongo distribution under the + :mod:`~pymongo.event_loggers` submodule. + +Use :func:`register` to register global listeners for specific events. +Listeners must inherit from one of the abstract classes below and implement +the correct functions for that class. + +For example, a simple command logger might be implemented like this:: + + import logging + + from pymongo import monitoring + + class CommandLogger(monitoring.CommandListener): + + def started(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} started on server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "succeeded in {0.duration_micros} " + "microseconds".format(event)) + + def failed(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "failed in {0.duration_micros} " + "microseconds".format(event)) + + monitoring.register(CommandLogger()) + +Server discovery and monitoring events are also available. 
For example:: + + class ServerLogger(monitoring.ServerListener): + + def opened(self, event): + logging.info("Server {0.server_address} added to topology " + "{0.topology_id}".format(event)) + + def description_changed(self, event): + previous_server_type = event.previous_description.server_type + new_server_type = event.new_description.server_type + if new_server_type != previous_server_type: + # server_type_name was added in PyMongo 3.4 + logging.info( + "Server {0.server_address} changed type from " + "{0.previous_description.server_type_name} to " + "{0.new_description.server_type_name}".format(event)) + + def closed(self, event): + logging.warning("Server {0.server_address} removed from topology " + "{0.topology_id}".format(event)) + + + class HeartbeatLogger(monitoring.ServerHeartbeatListener): + + def started(self, event): + logging.info("Heartbeat sent to server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + # The reply.document attribute was added in PyMongo 3.4. 
+ logging.info("Heartbeat to server {0.connection_id} " + "succeeded with reply " + "{0.reply.document}".format(event)) + + def failed(self, event): + logging.warning("Heartbeat to server {0.connection_id} " + "failed with error {0.reply}".format(event)) + + class TopologyLogger(monitoring.TopologyListener): + + def opened(self, event): + logging.info("Topology with id {0.topology_id} " + "opened".format(event)) + + def description_changed(self, event): + logging.info("Topology description updated for " + "topology id {0.topology_id}".format(event)) + previous_topology_type = event.previous_description.topology_type + new_topology_type = event.new_description.topology_type + if new_topology_type != previous_topology_type: + # topology_type_name was added in PyMongo 3.4 + logging.info( + "Topology {0.topology_id} changed type from " + "{0.previous_description.topology_type_name} to " + "{0.new_description.topology_type_name}".format(event)) + # The has_writable_server and has_readable_server methods + # were added in PyMongo 3.4. + if not event.new_description.has_writable_server(): + logging.warning("No writable servers available.") + if not event.new_description.has_readable_server(): + logging.warning("No readable servers available.") + + def closed(self, event): + logging.info("Topology with id {0.topology_id} " + "closed".format(event)) + +Connection monitoring and pooling events are also available. 
For example:: + + class ConnectionPoolLogger(ConnectionPoolListener): + + def pool_created(self, event): + logging.info("[pool {0.address}] pool created".format(event)) + + def pool_ready(self, event): + logging.info("[pool {0.address}] pool is ready".format(event)) + + def pool_cleared(self, event): + logging.info("[pool {0.address}] pool cleared".format(event)) + + def pool_closed(self, event): + logging.info("[pool {0.address}] pool closed".format(event)) + + def connection_created(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection created".format(event)) + + def connection_ready(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection setup succeeded".format(event)) + + def connection_closed(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection closed, reason: " + "{0.reason}".format(event)) + + def connection_check_out_started(self, event): + logging.info("[pool {0.address}] connection check out " + "started".format(event)) + + def connection_check_out_failed(self, event): + logging.info("[pool {0.address}] connection check out " + "failed, reason: {0.reason}".format(event)) + + def connection_checked_out(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection checked out of pool".format(event)) + + def connection_checked_in(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection checked into pool".format(event)) + + +Event listeners can also be registered per instance of +:class:`~pymongo.mongo_client.MongoClient`:: + + client = MongoClient(event_listeners=[CommandLogger()]) + +Note that previously registered global listeners are automatically included +when configuring per client event listeners. Registering a new global listener +will not add that listener to existing client instances. + +.. note:: Events are delivered **synchronously**. 
Application threads block + waiting for event handlers (e.g. :meth:`~CommandListener.started`) to + return. Care must be taken to ensure that your event handlers are efficient + enough to not adversely affect overall application performance. + +.. warning:: The command documents published through this API are *not* copies. + If you intend to modify them in any way you must copy them in your event + handler first. +""" -"""Re-import of synchronous Monitoring API for compatibility.""" from __future__ import annotations -from pymongo.synchronous.monitoring import * # noqa: F403 -from pymongo.synchronous.monitoring import __doc__ as original_doc +import datetime +from collections import abc, namedtuple +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence -__doc__ = original_doc +from bson.objectid import ObjectId +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers_shared import _SENSITIVE_COMMANDS, _handle_exception +from pymongo.typings import _Address, _DocumentOut + +if TYPE_CHECKING: + from datetime import timedelta + + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription + + +_Listeners = namedtuple( + "_Listeners", + ( + "command_listeners", + "server_listeners", + "server_heartbeat_listeners", + "topology_listeners", + "cmap_listeners", + ), +) + +_LISTENERS = _Listeners([], [], [], [], []) + + +class _EventListener: + """Abstract base class for all event listeners.""" + + +class CommandListener(_EventListener): + """Abstract base class for command listeners. + + Handles `CommandStartedEvent`, `CommandSucceededEvent`, + and `CommandFailedEvent`. + """ + + def started(self, event: CommandStartedEvent) -> None: + """Abstract method to handle a `CommandStartedEvent`. + + :param event: An instance of :class:`CommandStartedEvent`. 
+ """ + raise NotImplementedError + + def succeeded(self, event: CommandSucceededEvent) -> None: + """Abstract method to handle a `CommandSucceededEvent`. + + :param event: An instance of :class:`CommandSucceededEvent`. + """ + raise NotImplementedError + + def failed(self, event: CommandFailedEvent) -> None: + """Abstract method to handle a `CommandFailedEvent`. + + :param event: An instance of :class:`CommandFailedEvent`. + """ + raise NotImplementedError + + +class ConnectionPoolListener(_EventListener): + """Abstract base class for connection pool listeners. + + Handles all of the connection pool events defined in the Connection + Monitoring and Pooling Specification: + :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, + :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, + :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, + :class:`ConnectionCheckOutStartedEvent`, + :class:`ConnectionCheckOutFailedEvent`, + :class:`ConnectionCheckedOutEvent`, + and :class:`ConnectionCheckedInEvent`. + + .. versionadded:: 3.9 + """ + + def pool_created(self, event: PoolCreatedEvent) -> None: + """Abstract method to handle a :class:`PoolCreatedEvent`. + + Emitted when a connection Pool is created. + + :param event: An instance of :class:`PoolCreatedEvent`. + """ + raise NotImplementedError + + def pool_ready(self, event: PoolReadyEvent) -> None: + """Abstract method to handle a :class:`PoolReadyEvent`. + + Emitted when a connection Pool is marked ready. + + :param event: An instance of :class:`PoolReadyEvent`. + + .. versionadded:: 4.0 + """ + raise NotImplementedError + + def pool_cleared(self, event: PoolClearedEvent) -> None: + """Abstract method to handle a `PoolClearedEvent`. + + Emitted when a connection Pool is cleared. + + :param event: An instance of :class:`PoolClearedEvent`. + """ + raise NotImplementedError + + def pool_closed(self, event: PoolClosedEvent) -> None: + """Abstract method to handle a `PoolClosedEvent`. 
+ + Emitted when a connection Pool is closed. + + :param event: An instance of :class:`PoolClosedEvent`. + """ + raise NotImplementedError + + def connection_created(self, event: ConnectionCreatedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCreatedEvent`. + + Emitted when a connection Pool creates a Connection object. + + :param event: An instance of :class:`ConnectionCreatedEvent`. + """ + raise NotImplementedError + + def connection_ready(self, event: ConnectionReadyEvent) -> None: + """Abstract method to handle a :class:`ConnectionReadyEvent`. + + Emitted when a connection has finished its setup, and is now ready to + use. + + :param event: An instance of :class:`ConnectionReadyEvent`. + """ + raise NotImplementedError + + def connection_closed(self, event: ConnectionClosedEvent) -> None: + """Abstract method to handle a :class:`ConnectionClosedEvent`. + + Emitted when a connection Pool closes a connection. + + :param event: An instance of :class:`ConnectionClosedEvent`. + """ + raise NotImplementedError + + def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. + + Emitted when the driver starts attempting to check out a connection. + + :param event: An instance of :class:`ConnectionCheckOutStartedEvent`. + """ + raise NotImplementedError + + def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. + + Emitted when the driver's attempt to check out a connection fails. + + :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. + """ + raise NotImplementedError + + def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. + + Emitted when the driver successfully checks out a connection. 
+ + :param event: An instance of :class:`ConnectionCheckedOutEvent`. + """ + raise NotImplementedError + + def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckedInEvent`. + + Emitted when the driver checks in a connection back to the connection + Pool. + + :param event: An instance of :class:`ConnectionCheckedInEvent`. + """ + raise NotImplementedError + + +class ServerHeartbeatListener(_EventListener): + """Abstract base class for server heartbeat listeners. + + Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, + and `ServerHeartbeatFailedEvent`. + + .. versionadded:: 3.3 + """ + + def started(self, event: ServerHeartbeatStartedEvent) -> None: + """Abstract method to handle a `ServerHeartbeatStartedEvent`. + + :param event: An instance of :class:`ServerHeartbeatStartedEvent`. + """ + raise NotImplementedError + + def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: + """Abstract method to handle a `ServerHeartbeatSucceededEvent`. + + :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. + """ + raise NotImplementedError + + def failed(self, event: ServerHeartbeatFailedEvent) -> None: + """Abstract method to handle a `ServerHeartbeatFailedEvent`. + + :param event: An instance of :class:`ServerHeartbeatFailedEvent`. + """ + raise NotImplementedError + + +class TopologyListener(_EventListener): + """Abstract base class for topology monitoring listeners. + Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and + `TopologyClosedEvent`. + + .. versionadded:: 3.3 + """ + + def opened(self, event: TopologyOpenedEvent) -> None: + """Abstract method to handle a `TopologyOpenedEvent`. + + :param event: An instance of :class:`TopologyOpenedEvent`. + """ + raise NotImplementedError + + def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: + """Abstract method to handle a `TopologyDescriptionChangedEvent`. 
+ + :param event: An instance of :class:`TopologyDescriptionChangedEvent`. + """ + raise NotImplementedError + + def closed(self, event: TopologyClosedEvent) -> None: + """Abstract method to handle a `TopologyClosedEvent`. + + :param event: An instance of :class:`TopologyClosedEvent`. + """ + raise NotImplementedError + + +class ServerListener(_EventListener): + """Abstract base class for server listeners. + Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and + `ServerClosedEvent`. + + .. versionadded:: 3.3 + """ + + def opened(self, event: ServerOpeningEvent) -> None: + """Abstract method to handle a `ServerOpeningEvent`. + + :param event: An instance of :class:`ServerOpeningEvent`. + """ + raise NotImplementedError + + def description_changed(self, event: ServerDescriptionChangedEvent) -> None: + """Abstract method to handle a `ServerDescriptionChangedEvent`. + + :param event: An instance of :class:`ServerDescriptionChangedEvent`. + """ + raise NotImplementedError + + def closed(self, event: ServerClosedEvent) -> None: + """Abstract method to handle a `ServerClosedEvent`. + + :param event: An instance of :class:`ServerClosedEvent`. + """ + raise NotImplementedError + + +def _to_micros(dur: timedelta) -> int: + """Convert duration 'dur' to microseconds.""" + return int(dur.total_seconds() * 10e5) + + +def _validate_event_listeners( + option: str, listeners: Sequence[_EventListeners] +) -> Sequence[_EventListeners]: + """Validate event listeners""" + if not isinstance(listeners, abc.Sequence): + raise TypeError(f"{option} must be a list or tuple") + for listener in listeners: + if not isinstance(listener, _EventListener): + raise TypeError( + f"Listeners for {option} must be either a " + "CommandListener, ServerHeartbeatListener, " + "ServerListener, TopologyListener, or " + "ConnectionPoolListener." + ) + return listeners + + +def register(listener: _EventListener) -> None: + """Register a global event listener. 
+ + :param listener: A subclass of :class:`CommandListener`, + :class:`ServerHeartbeatListener`, :class:`ServerListener`, + :class:`TopologyListener`, or :class:`ConnectionPoolListener`. + """ + if not isinstance(listener, _EventListener): + raise TypeError( + f"Listeners for {listener} must be either a " + "CommandListener, ServerHeartbeatListener, " + "ServerListener, TopologyListener, or " + "ConnectionPoolListener." + ) + if isinstance(listener, CommandListener): + _LISTENERS.command_listeners.append(listener) + if isinstance(listener, ServerHeartbeatListener): + _LISTENERS.server_heartbeat_listeners.append(listener) + if isinstance(listener, ServerListener): + _LISTENERS.server_listeners.append(listener) + if isinstance(listener, TopologyListener): + _LISTENERS.topology_listeners.append(listener) + if isinstance(listener, ConnectionPoolListener): + _LISTENERS.cmap_listeners.append(listener) + + +# The "hello" command is also deemed sensitive when attempting speculative +# authentication. 
+def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: + if ( + command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) + and "speculativeAuthenticate" in doc + ): + return True + return False + + +class _CommandEvent: + """Base class for command events.""" + + __slots__ = ( + "__cmd_name", + "__rqst_id", + "__conn_id", + "__op_id", + "__service_id", + "__db", + "__server_conn_id", + ) + + def __init__( + self, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + database_name: str = "", + server_connection_id: Optional[int] = None, + ) -> None: + self.__cmd_name = command_name + self.__rqst_id = request_id + self.__conn_id = connection_id + self.__op_id = operation_id + self.__service_id = service_id + self.__db = database_name + self.__server_conn_id = server_connection_id + + @property + def command_name(self) -> str: + """The command name.""" + return self.__cmd_name + + @property + def request_id(self) -> int: + """The request id for this operation.""" + return self.__rqst_id + + @property + def connection_id(self) -> _Address: + """The address (host, port) of the server this command was sent to.""" + return self.__conn_id + + @property + def service_id(self) -> Optional[ObjectId]: + """The service_id this command was sent to, or ``None``. + + .. versionadded:: 3.12 + """ + return self.__service_id + + @property + def operation_id(self) -> Optional[int]: + """An id for this series of events or None.""" + return self.__op_id + + @property + def database_name(self) -> str: + """The database_name this command was sent to, or ``""``. + + .. versionadded:: 4.6 + """ + return self.__db + + @property + def server_connection_id(self) -> Optional[int]: + """The server-side connection id for the connection this command was sent on, or ``None``. + + .. 
versionadded:: 4.7 + """ + return self.__server_conn_id + + +class CommandStartedEvent(_CommandEvent): + """Event published when a command starts. + + :param command: The command document. + :param database_name: The name of the database this command was run against. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this command + was sent to. + :param operation_id: An optional identifier for a series of related events. + :param service_id: The service_id this command was sent to, or ``None``. + """ + + __slots__ = ("__cmd",) + + def __init__( + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + server_connection_id: Optional[int] = None, + ) -> None: + if not command: + raise ValueError(f"{command!r} is not a valid command") + # Command name must be first key. + command_name = next(iter(command)) + super().__init__( + command_name, + request_id, + connection_id, + operation_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + cmd_name = command_name.lower() + if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command): + self.__cmd: _DocumentOut = {} + else: + self.__cmd = command + + @property + def command(self) -> _DocumentOut: + """The command document.""" + return self.__cmd + + @property + def database_name(self) -> str: + """The name of the database this command was run against.""" + return super().database_name + + def __repr__(self) -> str: + return ( + "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>" + ).format( + self.__class__.__name__, + self.connection_id, + self.database_name, + self.command_name, + self.operation_id, + self.service_id, + self.server_connection_id, + ) + + +class CommandSucceededEvent(_CommandEvent): + """Event 
published when a command succeeds. + + :param duration: The command duration as a datetime.timedelta. + :param reply: The server reply document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this command + was sent to. + :param operation_id: An optional identifier for a series of related events. + :param service_id: The service_id this command was sent to, or ``None``. + :param database_name: The database this command was sent to, or ``""``. + """ + + __slots__ = ("__duration_micros", "__reply") + + def __init__( + self, + duration: datetime.timedelta, + reply: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + database_name: str = "", + server_connection_id: Optional[int] = None, + ) -> None: + super().__init__( + command_name, + request_id, + connection_id, + operation_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + self.__duration_micros = _to_micros(duration) + cmd_name = command_name.lower() + if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply): + self.__reply: _DocumentOut = {} + else: + self.__reply = reply + + @property + def duration_micros(self) -> int: + """The duration of this operation in microseconds.""" + return self.__duration_micros + + @property + def reply(self) -> _DocumentOut: + """The server reply document for this operation.""" + return self.__reply + + def __repr__(self) -> str: + return ( + "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>" + ).format( + self.__class__.__name__, + self.connection_id, + self.database_name, + self.command_name, + self.operation_id, + self.duration_micros, + self.service_id, + self.server_connection_id, + ) + + +class 
CommandFailedEvent(_CommandEvent): + """Event published when a command fails. + + :param duration: The command duration as a datetime.timedelta. + :param failure: The server reply document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this command + was sent to. + :param operation_id: An optional identifier for a series of related events. + :param service_id: The service_id this command was sent to, or ``None``. + :param database_name: The database this command was sent to, or ``""``. + """ + + __slots__ = ("__duration_micros", "__failure") + + def __init__( + self, + duration: datetime.timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + database_name: str = "", + server_connection_id: Optional[int] = None, + ) -> None: + super().__init__( + command_name, + request_id, + connection_id, + operation_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + self.__duration_micros = _to_micros(duration) + self.__failure = failure + + @property + def duration_micros(self) -> int: + """The duration of this operation in microseconds.""" + return self.__duration_micros + + @property + def failure(self) -> _DocumentOut: + """The server failure document for this operation.""" + return self.__failure + + def __repr__(self) -> str: + return ( + "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, " + "failure: {!r}, service_id: {}, server_connection_id: {}>" + ).format( + self.__class__.__name__, + self.connection_id, + self.database_name, + self.command_name, + self.operation_id, + self.duration_micros, + self.failure, + self.service_id, + self.server_connection_id, + ) + + +class _PoolEvent: + """Base class for pool events.""" + + __slots__ = ("__address",) + + def 
__init__(self, address: _Address) -> None: + self.__address = address + + @property + def address(self) -> _Address: + """The address (host, port) pair of the server the pool is attempting + to connect to. + """ + return self.__address + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__address!r})" + + +class PoolCreatedEvent(_PoolEvent): + """Published when a Connection Pool is created. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__options",) + + def __init__(self, address: _Address, options: dict[str, Any]) -> None: + super().__init__(address) + self.__options = options + + @property + def options(self) -> dict[str, Any]: + """Any non-default pool options that were set on this Connection Pool.""" + return self.__options + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" + + +class PoolReadyEvent(_PoolEvent): + """Published when a Connection Pool is marked ready. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 4.0 + """ + + __slots__ = () + + +class PoolClearedEvent(_PoolEvent): + """Published when a Connection Pool is cleared. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + :param service_id: The service_id this command was sent to, or ``None``. + :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing. + + .. 
versionadded:: 3.9 + """ + + __slots__ = ("__service_id", "__interrupt_connections") + + def __init__( + self, + address: _Address, + service_id: Optional[ObjectId] = None, + interrupt_connections: bool = False, + ) -> None: + super().__init__(address) + self.__service_id = service_id + self.__interrupt_connections = interrupt_connections + + @property + def service_id(self) -> Optional[ObjectId]: + """Connections with this service_id are cleared. + + When service_id is ``None``, all connections in the pool are cleared. + + .. versionadded:: 3.12 + """ + return self.__service_id + + @property + def interrupt_connections(self) -> bool: + """If True, active connections are interrupted during clearing. + + .. versionadded:: 4.7 + """ + return self.__interrupt_connections + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" + + +class PoolClosedEvent(_PoolEvent): + """Published when a Connection Pool is closed. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionClosedReason: + """An enum that defines values for `reason` on a + :class:`ConnectionClosedEvent`. + + .. versionadded:: 3.9 + """ + + STALE = "stale" + """The pool was cleared, making the connection no longer valid.""" + + IDLE = "idle" + """The connection became stale by being idle for too long (maxIdleTimeMS). + """ + + ERROR = "error" + """The connection experienced an error, making it no longer valid.""" + + POOL_CLOSED = "poolClosed" + """The pool was closed, making the connection no longer valid.""" + + +class ConnectionCheckOutFailedReason: + """An enum that defines values for `reason` on a + :class:`ConnectionCheckOutFailedEvent`. + + .. 
versionadded:: 3.9 + """ + + TIMEOUT = "timeout" + """The connection check out attempt exceeded the specified timeout.""" + + POOL_CLOSED = "poolClosed" + """The pool was previously closed, and cannot provide new connections.""" + + CONN_ERROR = "connectionError" + """The connection check out attempt experienced an error while setting up + a new connection. + """ + + +class _ConnectionEvent: + """Private base class for connection events.""" + + __slots__ = ("__address",) + + def __init__(self, address: _Address) -> None: + self.__address = address + + @property + def address(self) -> _Address: + """The address (host, port) pair of the server this connection is + attempting to connect to. + """ + return self.__address + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__address!r})" + + +class _ConnectionIdEvent(_ConnectionEvent): + """Private base class for connection events with an id.""" + + __slots__ = ("__connection_id",) + + def __init__(self, address: _Address, connection_id: int) -> None: + super().__init__(address) + self.__connection_id = connection_id + + @property + def connection_id(self) -> int: + """The ID of the connection.""" + return self.__connection_id + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" + + +class _ConnectionDurationEvent(_ConnectionIdEvent): + """Private base class for connection events with a duration.""" + + __slots__ = ("__duration",) + + def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: + super().__init__(address, connection_id) + self.__duration = duration + + @property + def duration(self) -> Optional[float]: + """The duration of the connection event. + + .. 
versionadded:: 4.7 + """ + return self.__duration + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" + + +class ConnectionCreatedEvent(_ConnectionIdEvent): + """Published when a Connection Pool creates a Connection object. + + NOTE: This connection is not ready for use until the + :class:`ConnectionReadyEvent` is published. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionReadyEvent(_ConnectionDurationEvent): + """Published when a Connection has finished its setup, and is ready to use. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionClosedEvent(_ConnectionIdEvent): + """Published when a Connection is closed. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + :param reason: A reason explaining why this connection was closed. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__reason",) + + def __init__(self, address: _Address, connection_id: int, reason: str): + super().__init__(address, connection_id) + self.__reason = reason + + @property + def reason(self) -> str: + """A reason explaining why this connection was closed. + + The reason must be one of the strings from the + :class:`ConnectionClosedReason` enum. 
+ """ + return self.__reason + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r})".format( + self.__class__.__name__, + self.address, + self.connection_id, + self.__reason, + ) + + +class ConnectionCheckOutStartedEvent(_ConnectionEvent): + """Published when the driver starts attempting to check out a connection. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): + """Published when the driver's attempt to check out a connection fails. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param reason: A reason explaining why connection check out failed. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__reason",) + + def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: + super().__init__(address=address, connection_id=0, duration=duration) + self.__reason = reason + + @property + def reason(self) -> str: + """A reason explaining why connection check out failed. + + The reason must be one of the strings from the + :class:`ConnectionCheckOutFailedReason` enum. + """ + return self.__reason + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" + + +class ConnectionCheckedOutEvent(_ConnectionDurationEvent): + """Published when the driver successfully checks out a connection. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionCheckedInEvent(_ConnectionIdEvent): + """Published when the driver checks in a Connection into the Pool. 
+ + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class _ServerEvent: + """Base class for server events.""" + + __slots__ = ("__server_address", "__topology_id") + + def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: + self.__server_address = server_address + self.__topology_id = topology_id + + @property + def server_address(self) -> _Address: + """The address (host, port) pair of the server""" + return self.__server_address + + @property + def topology_id(self) -> ObjectId: + """A unique identifier for the topology this server is a part of.""" + return self.__topology_id + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" + + +class ServerDescriptionChangedEvent(_ServerEvent): + """Published when server description changes. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__previous_description", "__new_description") + + def __init__( + self, + previous_description: ServerDescription, + new_description: ServerDescription, + *args: Any, + ) -> None: + super().__init__(*args) + self.__previous_description = previous_description + self.__new_description = new_description + + @property + def previous_description(self) -> ServerDescription: + """The previous + :class:`~pymongo.server_description.ServerDescription`. + """ + return self.__previous_description + + @property + def new_description(self) -> ServerDescription: + """The new + :class:`~pymongo.server_description.ServerDescription`. 
+ """ + return self.__new_description + + def __repr__(self) -> str: + return "<{} {} changed from: {}, to: {}>".format( + self.__class__.__name__, + self.server_address, + self.previous_description, + self.new_description, + ) + + +class ServerOpeningEvent(_ServerEvent): + """Published when server is initialized. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class ServerClosedEvent(_ServerEvent): + """Published when server is closed. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class TopologyEvent: + """Base class for topology description events.""" + + __slots__ = ("__topology_id",) + + def __init__(self, topology_id: ObjectId) -> None: + self.__topology_id = topology_id + + @property + def topology_id(self) -> ObjectId: + """A unique identifier for the topology this server is a part of.""" + return self.__topology_id + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" + + +class TopologyDescriptionChangedEvent(TopologyEvent): + """Published when the topology description changes. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__previous_description", "__new_description") + + def __init__( + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + *args: Any, + ) -> None: + super().__init__(*args) + self.__previous_description = previous_description + self.__new_description = new_description + + @property + def previous_description(self) -> TopologyDescription: + """The previous + :class:`~pymongo.topology_description.TopologyDescription`. + """ + return self.__previous_description + + @property + def new_description(self) -> TopologyDescription: + """The new + :class:`~pymongo.topology_description.TopologyDescription`. 
+ """ + return self.__new_description + + def __repr__(self) -> str: + return "<{} topology_id: {} changed from: {}, to: {}>".format( + self.__class__.__name__, + self.topology_id, + self.previous_description, + self.new_description, + ) + + +class TopologyOpenedEvent(TopologyEvent): + """Published when the topology is initialized. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class TopologyClosedEvent(TopologyEvent): + """Published when the topology is closed. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class _ServerHeartbeatEvent: + """Base class for server heartbeat events.""" + + __slots__ = ("__connection_id", "__awaited") + + def __init__(self, connection_id: _Address, awaited: bool = False) -> None: + self.__connection_id = connection_id + self.__awaited = awaited + + @property + def connection_id(self) -> _Address: + """The address (host, port) of the server this heartbeat was sent + to. + """ + return self.__connection_id + + @property + def awaited(self) -> bool: + """Whether the heartbeat was issued as an awaitable hello command. + + .. versionadded:: 4.6 + """ + return self.__awaited + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" + + +class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): + """Published when a heartbeat is started. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): + """Fired when the server heartbeat succeeds. + + .. 
versionadded:: 3.3 + """ + + __slots__ = ("__duration", "__reply") + + def __init__( + self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False + ) -> None: + super().__init__(connection_id, awaited) + self.__duration = duration + self.__reply = reply + + @property + def duration(self) -> float: + """The duration of this heartbeat in microseconds.""" + return self.__duration + + @property + def reply(self) -> Hello: + """An instance of :class:`~pymongo.hello.Hello`.""" + return self.__reply + + @property + def awaited(self) -> bool: + """Whether the heartbeat was awaited. + + If true, then :meth:`duration` reflects the sum of the round trip time + to the server and the time that the server waited before sending a + response. + + .. versionadded:: 3.11 + """ + return super().awaited + + def __repr__(self) -> str: + return "<{} {} duration: {}, awaited: {}, reply: {}>".format( + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) + + +class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): + """Fired when the server heartbeat fails, either with an "ok: 0" + or a socket exception. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__duration", "__reply") + + def __init__( + self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False + ) -> None: + super().__init__(connection_id, awaited) + self.__duration = duration + self.__reply = reply + + @property + def duration(self) -> float: + """The duration of this heartbeat in microseconds.""" + return self.__duration + + @property + def reply(self) -> Exception: + """A subclass of :exc:`Exception`.""" + return self.__reply + + @property + def awaited(self) -> bool: + """Whether the heartbeat was awaited. + + If true, then :meth:`duration` reflects the sum of the round trip time + to the server and the time that the server waited before sending a + response. + + .. 
versionadded:: 3.11 + """ + return super().awaited + + def __repr__(self) -> str: + return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) + + +class _EventListeners: + """Configure event listeners for a client instance. + + Any event listeners registered globally are included by default. + + :param listeners: A list of event listeners. + """ + + def __init__(self, listeners: Optional[Sequence[_EventListener]]): + self.__command_listeners = _LISTENERS.command_listeners[:] + self.__server_listeners = _LISTENERS.server_listeners[:] + lst = _LISTENERS.server_heartbeat_listeners + self.__server_heartbeat_listeners = lst[:] + self.__topology_listeners = _LISTENERS.topology_listeners[:] + self.__cmap_listeners = _LISTENERS.cmap_listeners[:] + if listeners is not None: + for lst in listeners: + if isinstance(lst, CommandListener): + self.__command_listeners.append(lst) + if isinstance(lst, ServerListener): + self.__server_listeners.append(lst) + if isinstance(lst, ServerHeartbeatListener): + self.__server_heartbeat_listeners.append(lst) + if isinstance(lst, TopologyListener): + self.__topology_listeners.append(lst) + if isinstance(lst, ConnectionPoolListener): + self.__cmap_listeners.append(lst) + self.__enabled_for_commands = bool(self.__command_listeners) + self.__enabled_for_server = bool(self.__server_listeners) + self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) + self.__enabled_for_topology = bool(self.__topology_listeners) + self.__enabled_for_cmap = bool(self.__cmap_listeners) + + @property + def enabled_for_commands(self) -> bool: + """Are any CommandListener instances registered?""" + return self.__enabled_for_commands + + @property + def enabled_for_server(self) -> bool: + """Are any ServerListener instances registered?""" + return self.__enabled_for_server + + @property + def enabled_for_server_heartbeat(self) -> bool: + 
"""Are any ServerHeartbeatListener instances registered?""" + return self.__enabled_for_server_heartbeat + + @property + def enabled_for_topology(self) -> bool: + """Are any TopologyListener instances registered?""" + return self.__enabled_for_topology + + @property + def enabled_for_cmap(self) -> bool: + """Are any ConnectionPoolListener instances registered?""" + return self.__enabled_for_cmap + + def event_listeners(self) -> list[_EventListeners]: + """List of registered event listeners.""" + return ( + self.__command_listeners + + self.__server_heartbeat_listeners + + self.__server_listeners + + self.__topology_listeners + + self.__cmap_listeners + ) + + def publish_command_start( + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + ) -> None: + """Publish a CommandStartedEvent to all command listeners. + + :param command: The command document. + :param database_name: The name of the database this command was run + against. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. 
+ """ + if op_id is None: + op_id = request_id + event = CommandStartedEvent( + command, + database_name, + request_id, + connection_id, + op_id, + service_id=service_id, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.started(event) + except Exception: + _handle_exception() + + def publish_command_success( + self, + duration: timedelta, + reply: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + speculative_hello: bool = False, + database_name: str = "", + ) -> None: + """Publish a CommandSucceededEvent to all command listeners. + + :param duration: The command duration as a datetime.timedelta. + :param reply: The server reply document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. + :param speculative_hello: Was the command sent with speculative auth? + :param database_name: The database this command was sent to, or ``""``. + """ + if op_id is None: + op_id = request_id + if speculative_hello: + # Redact entire response when the command started contained + # speculativeAuthenticate. 
+ reply = {} + event = CommandSucceededEvent( + duration, + reply, + command_name, + request_id, + connection_id, + op_id, + service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.succeeded(event) + except Exception: + _handle_exception() + + def publish_command_failure( + self, + duration: timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + database_name: str = "", + ) -> None: + """Publish a CommandFailedEvent to all command listeners. + + :param duration: The command duration as a datetime.timedelta. + :param failure: The server reply document or failure description + document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. + :param database_name: The database this command was sent to, or ``""``. + """ + if op_id is None: + op_id = request_id + event = CommandFailedEvent( + duration, + failure, + command_name, + request_id, + connection_id, + op_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.failed(event) + except Exception: + _handle_exception() + + def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: + """Publish a ServerHeartbeatStartedEvent to all server heartbeat + listeners. + + :param connection_id: The address (host, port) pair of the connection. + :param awaited: True if this heartbeat is part of an awaitable hello command. 
+ """ + event = ServerHeartbeatStartedEvent(connection_id, awaited) + for subscriber in self.__server_heartbeat_listeners: + try: + subscriber.started(event) + except Exception: + _handle_exception() + + def publish_server_heartbeat_succeeded( + self, connection_id: _Address, duration: float, reply: Hello, awaited: bool + ) -> None: + """Publish a ServerHeartbeatSucceededEvent to all server heartbeat + listeners. + + :param connection_id: The address (host, port) pair of the connection. + :param duration: The execution time of the event in the highest possible + resolution for the platform. + :param reply: The command reply. + :param awaited: True if the response was awaited. + """ + event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) + for subscriber in self.__server_heartbeat_listeners: + try: + subscriber.succeeded(event) + except Exception: + _handle_exception() + + def publish_server_heartbeat_failed( + self, connection_id: _Address, duration: float, reply: Exception, awaited: bool + ) -> None: + """Publish a ServerHeartbeatFailedEvent to all server heartbeat + listeners. + + :param connection_id: The address (host, port) pair of the connection. + :param duration: The execution time of the event in the highest possible + resolution for the platform. + :param reply: The command reply. + :param awaited: True if the response was awaited. + """ + event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) + for subscriber in self.__server_heartbeat_listeners: + try: + subscriber.failed(event) + except Exception: + _handle_exception() + + def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None: + """Publish a ServerOpeningEvent to all server listeners. + + :param server_address: The address (host, port) pair of the server. + :param topology_id: A unique identifier for the topology this server + is a part of. 
+ """ + event = ServerOpeningEvent(server_address, topology_id) + for subscriber in self.__server_listeners: + try: + subscriber.opened(event) + except Exception: + _handle_exception() + + def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None: + """Publish a ServerClosedEvent to all server listeners. + + :param server_address: The address (host, port) pair of the server. + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = ServerClosedEvent(server_address, topology_id) + for subscriber in self.__server_listeners: + try: + subscriber.closed(event) + except Exception: + _handle_exception() + + def publish_server_description_changed( + self, + previous_description: ServerDescription, + new_description: ServerDescription, + server_address: _Address, + topology_id: ObjectId, + ) -> None: + """Publish a ServerDescriptionChangedEvent to all server listeners. + + :param previous_description: The previous server description. + :param server_address: The address (host, port) pair of the server. + :param new_description: The new server description. + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = ServerDescriptionChangedEvent( + previous_description, new_description, server_address, topology_id + ) + for subscriber in self.__server_listeners: + try: + subscriber.description_changed(event) + except Exception: + _handle_exception() + + def publish_topology_opened(self, topology_id: ObjectId) -> None: + """Publish a TopologyOpenedEvent to all topology listeners. + + :param topology_id: A unique identifier for the topology this server + is a part of. 
+ """ + event = TopologyOpenedEvent(topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.opened(event) + except Exception: + _handle_exception() + + def publish_topology_closed(self, topology_id: ObjectId) -> None: + """Publish a TopologyClosedEvent to all topology listeners. + + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyClosedEvent(topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.closed(event) + except Exception: + _handle_exception() + + def publish_topology_description_changed( + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + topology_id: ObjectId, + ) -> None: + """Publish a TopologyDescriptionChangedEvent to all topology listeners. + + :param previous_description: The previous topology description. + :param new_description: The new topology description. + :param topology_id: A unique identifier for the topology this server + is a part of. 
+ """ + event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.description_changed(event) + except Exception: + _handle_exception() + + def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: + """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" + event = PoolCreatedEvent(address, options) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_created(event) + except Exception: + _handle_exception() + + def publish_pool_ready(self, address: _Address) -> None: + """Publish a :class:`PoolReadyEvent` to all pool listeners.""" + event = PoolReadyEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_ready(event) + except Exception: + _handle_exception() + + def publish_pool_cleared( + self, + address: _Address, + service_id: Optional[ObjectId], + interrupt_connections: bool = False, + ) -> None: + """Publish a :class:`PoolClearedEvent` to all pool listeners.""" + event = PoolClearedEvent(address, service_id, interrupt_connections) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_cleared(event) + except Exception: + _handle_exception() + + def publish_pool_closed(self, address: _Address) -> None: + """Publish a :class:`PoolClosedEvent` to all pool listeners.""" + event = PoolClosedEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_closed(event) + except Exception: + _handle_exception() + + def publish_connection_created(self, address: _Address, connection_id: int) -> None: + """Publish a :class:`ConnectionCreatedEvent` to all connection + listeners. 
+ """ + event = ConnectionCreatedEvent(address, connection_id) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_created(event) + except Exception: + _handle_exception() + + def publish_connection_ready( + self, address: _Address, connection_id: int, duration: float + ) -> None: + """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" + event = ConnectionReadyEvent(address, connection_id, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_ready(event) + except Exception: + _handle_exception() + + def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: + """Publish a :class:`ConnectionClosedEvent` to all connection + listeners. + """ + event = ConnectionClosedEvent(address, connection_id, reason) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_closed(event) + except Exception: + _handle_exception() + + def publish_connection_check_out_started(self, address: _Address) -> None: + """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection + listeners. + """ + event = ConnectionCheckOutStartedEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_check_out_started(event) + except Exception: + _handle_exception() + + def publish_connection_check_out_failed( + self, address: _Address, reason: str, duration: float + ) -> None: + """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection + listeners. + """ + event = ConnectionCheckOutFailedEvent(address, reason, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_check_out_failed(event) + except Exception: + _handle_exception() + + def publish_connection_checked_out( + self, address: _Address, connection_id: int, duration: float + ) -> None: + """Publish a :class:`ConnectionCheckedOutEvent` to all connection + listeners. 
+ """ + event = ConnectionCheckedOutEvent(address, connection_id, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_checked_out(event) + except Exception: + _handle_exception() + + def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: + """Publish a :class:`ConnectionCheckedInEvent` to all connection + listeners. + """ + event = ConnectionCheckedInEvent(address, connection_id) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_checked_in(event) + except Exception: + _handle_exception() diff --git a/pymongo/operations.py b/pymongo/operations.py index dbfc048a6..2967a2944 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,613 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-import of synchronous Operations API for compatibility.""" +"""Operation class definitions.""" from __future__ import annotations -from pymongo.synchronous.operations import * # noqa: F403 -from pymongo.synchronous.operations import __doc__ as original_doc +import enum +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) -__doc__ = original_doc +from bson.raw_bson import RawBSONDocument +from pymongo import helpers_shared +from pymongo.collation import validate_collation_or_none +from pymongo.common import validate_is_mapping, validate_list +from pymongo.helpers_shared import _gen_index_name, _index_document, _index_list +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.write_concern import validate_boolean + +if TYPE_CHECKING: + from pymongo.typings import _AgnosticBulk + + +# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary +_IndexList = Union[ + Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] +] +_IndexKeyHint = Union[str, _IndexList] + + +class _Op(str, enum.Enum): + ABORT = "abortTransaction" + AGGREGATE = "aggregate" + COMMIT = "commitTransaction" + COUNT = "count" + CREATE = "create" + CREATE_INDEXES = "createIndexes" + CREATE_SEARCH_INDEXES = "createSearchIndexes" + DELETE = "delete" + DISTINCT = "distinct" + DROP = "drop" + DROP_DATABASE = "dropDatabase" + DROP_INDEXES = "dropIndexes" + DROP_SEARCH_INDEXES = "dropSearchIndexes" + END_SESSIONS = "endSessions" + FIND_AND_MODIFY = "findAndModify" + FIND = "find" + INSERT = "insert" + LIST_COLLECTIONS = "listCollections" + LIST_INDEXES = "listIndexes" + LIST_SEARCH_INDEX = "listSearchIndexes" + LIST_DATABASES = "listDatabases" + UPDATE = "update" + UPDATE_INDEX = "updateIndex" + UPDATE_SEARCH_INDEX = "updateSearchIndex" + RENAME = "rename" + GETMORE = "getMore" + KILL_CURSORS = "killCursors" + TEST = 
"testOperation" + + +class InsertOne(Generic[_DocumentType]): + """Represents an insert_one operation.""" + + __slots__ = ("_doc",) + + def __init__(self, document: _DocumentType) -> None: + """Create an InsertOne instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param document: The document to insert. If the document is missing an + _id field one will be added. + """ + self._doc = document + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_insert(self._doc) # type: ignore[arg-type] + + def __repr__(self) -> str: + return f"InsertOne({self._doc!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return other._doc == self._doc + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class DeleteOne: + """Represents a delete_one operation.""" + + __slots__ = ("_filter", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a DeleteOne instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to delete. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. 
versionchanged:: 3.5 + Added the `collation` option. + """ + if filter is not None: + validate_is_mapping("filter", filter) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._collation = collation + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_delete( + self._filter, + 1, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __repr__(self) -> str: + return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return (other._filter, other._collation, other._hint) == ( + self._filter, + self._collation, + self._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class DeleteMany: + """Represents a delete_many operation.""" + + __slots__ = ("_filter", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a DeleteMany instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the documents to delete. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. 
versionchanged:: 3.5 + Added the `collation` option. + """ + if filter is not None: + validate_is_mapping("filter", filter) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._collation = collation + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_delete( + self._filter, + 0, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __repr__(self) -> str: + return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return (other._filter, other._collation, other._hint) == ( + self._filter, + self._collation, + self._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class ReplaceOne(Generic[_DocumentType]): + """Represents a replace_one operation.""" + + __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + replacement: Union[_DocumentType, RawBSONDocument], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a ReplaceOne instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to replace. + :param replacement: The new document. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. 
+ :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the ``collation`` option. + """ + if filter is not None: + validate_is_mapping("filter", filter) + if upsert is not None: + validate_boolean("upsert", upsert) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._doc = replacement + self._upsert = upsert + self._collation = collation + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_replace( + self._filter, + self._doc, + self._upsert, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return ( + other._filter, + other._doc, + other._upsert, + other._collation, + other._hint, + ) == ( + self._filter, + self._doc, + self._upsert, + self._collation, + other._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( + self.__class__.__name__, + self._filter, + self._doc, + self._upsert, + self._collation, + self._hint, + ) + + +class _UpdateOp: + """Private base class for update operations.""" + + __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + doc: Union[Mapping[str, Any], _Pipeline], + upsert: bool, + collation: 
Optional[_CollationIn], + array_filters: Optional[list[Mapping[str, Any]]], + hint: Optional[_IndexKeyHint], + ): + if filter is not None: + validate_is_mapping("filter", filter) + if upsert is not None: + validate_boolean("upsert", upsert) + if array_filters is not None: + validate_list("array_filters", array_filters) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + + self._filter = filter + self._doc = doc + self._upsert = upsert + self._collation = collation + self._array_filters = array_filters + + def __eq__(self, other: object) -> bool: + if isinstance(other, type(self)): + return ( + other._filter, + other._doc, + other._upsert, + other._collation, + other._array_filters, + other._hint, + ) == ( + self._filter, + self._doc, + self._upsert, + self._collation, + self._array_filters, + self._hint, + ) + return NotImplemented + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( + self.__class__.__name__, + self._filter, + self._doc, + self._upsert, + self._collation, + self._array_filters, + self._hint, + ) + + +class UpdateOne(_UpdateOp): + """Represents an update_one operation.""" + + __slots__ = () + + def __init__( + self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Represents an update_one operation. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to update. + :param update: The modifications to apply. + :param upsert: If ``True``, perform an insert if no documents + match the filter. 
+ :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param array_filters: A list of filters specifying which + array elements an update should apply. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the `hint` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the `update`. + .. versionchanged:: 3.6 + Added the `array_filters` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + super().__init__(filter, update, upsert, collation, array_filters, hint) + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_update( + self._filter, + self._doc, + False, + self._upsert, + collation=validate_collation_or_none(self._collation), + array_filters=self._array_filters, + hint=self._hint, + ) + + +class UpdateMany(_UpdateOp): + """Represents an update_many operation.""" + + __slots__ = () + + def __init__( + self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create an UpdateMany instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the documents to update. + :param update: The modifications to apply. + :param upsert: If ``True``, perform an insert if no documents + match the filter. 
+ :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param array_filters: A list of filters specifying which + array elements an update should apply. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the `hint` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the `update`. + .. versionchanged:: 3.6 + Added the `array_filters` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + super().__init__(filter, update, upsert, collation, array_filters, hint) + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_update( + self._filter, + self._doc, + True, + self._upsert, + collation=validate_collation_or_none(self._collation), + array_filters=self._array_filters, + hint=self._hint, + ) + + +class IndexModel: + """Represents an index to create.""" + + __slots__ = ("__document",) + + def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: + """Create an Index instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_indexes` and :meth:`~pymongo.collection.Collection.create_indexes`. + + Takes either a single key or a list containing (key, direction) pairs + or keys. If no direction is given, :data:`~pymongo.ASCENDING` will + be assumed. + The key(s) must be an instance of :class:`str`, and the direction(s) must + be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, + :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, + :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). 
+ + Valid options include, but are not limited to: + + - `name`: custom name to use for this index - if none is + given, a name will be generated. + - `unique`: if ``True``, creates a uniqueness constraint on the index. + - `background`: if ``True``, this index should be created in the + background. + - `sparse`: if ``True``, omit from the index any documents that lack + the indexed field. + - `bucketSize`: for use with geoHaystack indexes. + Number of documents to group together within a certain proximity + to a given longitude and latitude. + - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` + index. + - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` + index. + - `expireAfterSeconds`: Used to create an expiring (TTL) + collection. MongoDB will automatically delete documents from + this collection after seconds. The indexed field must + be a UTC datetime or the data will not expire. + - `partialFilterExpression`: A document that specifies a filter for + a partial index. + - `collation`: An instance of :class:`~pymongo.collation.Collation` + that specifies the collation to use. + - `wildcardProjection`: Allows users to include or exclude specific + field paths from a `wildcard index`_ using the { "$**" : 1} key + pattern. Requires MongoDB >= 4.2. + - `hidden`: if ``True``, this index will be hidden from the query + planner and will not be evaluated as part of query plan + selection. Requires MongoDB >= 4.4. + + See the MongoDB documentation for a full list of supported options by + server version. + + :param keys: a single key or a list containing (key, direction) pairs + or keys specifying the index to create. + :param kwargs: any additional index creation + options (see the above list) should be passed as keyword + arguments. + + .. versionchanged:: 3.11 + Added the ``hidden`` option. + .. versionchanged:: 3.2 + Added the ``partialFilterExpression`` option to support partial + indexes. + + .. 
_wildcard index: https://mongodb.com/docs/master/core/index-wildcard/ + """ + keys = _index_list(keys) + if kwargs.get("name") is None: + kwargs["name"] = _gen_index_name(keys) + kwargs["key"] = _index_document(keys) + collation = validate_collation_or_none(kwargs.pop("collation", None)) + self.__document = kwargs + if collation is not None: + self.__document["collation"] = collation + + @property + def document(self) -> dict[str, Any]: + """An index document suitable for passing to the createIndexes + command. + """ + return self.__document + + +class SearchIndexModel: + """Represents a search index to create.""" + + __slots__ = ("__document",) + + def __init__( + self, + definition: Mapping[str, Any], + name: Optional[str] = None, + type: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Create a Search Index instance. + + For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`. + + :param definition: The definition for this index. + :param name: The name for this index, if present. + :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". + :param kwargs: Keyword arguments supplying any additional options. + + .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. + .. versionadded:: 4.5 + .. versionchanged:: 4.7 + Added the type and kwargs arguments. 
+ """ + self.__document: dict[str, Any] = {} + if name is not None: + self.__document["name"] = name + self.__document["definition"] = definition + if type is not None: + self.__document["type"] = type + self.__document.update(kwargs) + + @property + def document(self) -> Mapping[str, Any]: + """The document for this index.""" + return self.__document diff --git a/pymongo/pool.py b/pymongo/pool.py index 0045f227b..fbbb70fc6 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -19,3 +19,4 @@ from pymongo.synchronous.pool import * # noqa: F403 from pymongo.synchronous.pool import __doc__ as original_doc __doc__ = original_doc +__all__ = ["PoolOptions"] # noqa: F405 diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py new file mode 100644 index 000000000..668b82635 --- /dev/null +++ b/pymongo/pool_options.py @@ -0,0 +1,507 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +"""AsyncConnection pool options for AsyncMongoClient/MongoClient.""" +from __future__ import annotations + +import copy +import os +import platform +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any, MutableMapping, Optional + +import bson +from pymongo import __version__ +from pymongo.common import ( + MAX_CONNECTING, + MAX_IDLE_TIME_SEC, + MAX_POOL_SIZE, + MIN_POOL_SIZE, + WAIT_QUEUE_TIMEOUT, +) + +if TYPE_CHECKING: + from pymongo.auth_shared import MongoCredential + from pymongo.compression_support import CompressionSettings + from pymongo.driver_info import DriverInfo + from pymongo.monitoring import _EventListeners + from pymongo.pyopenssl_context import SSLContext + from pymongo.server_api import ServerApi + + +_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} + +if sys.platform.startswith("linux"): + # platform.linux_distribution was deprecated in Python 3.5 + # and removed in Python 3.8. Starting in Python 3.5 it + # raises DeprecationWarning + # DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5 + _name = platform.system() + _METADATA["os"] = { + "type": _name, + "name": _name, + "architecture": platform.machine(), + # Kernel version (e.g. 4.4.0-17-generic). + "version": platform.release(), + } +elif sys.platform == "darwin": + _METADATA["os"] = { + "type": platform.system(), + "name": platform.system(), + "architecture": platform.machine(), + # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin + # kernel version. + "version": platform.mac_ver()[0], + } +elif sys.platform == "win32": + _METADATA["os"] = { + "type": platform.system(), + # "Windows XP", "Windows 7", "Windows 10", etc. + "name": " ".join((platform.system(), platform.release())), + "architecture": platform.machine(), + # Windows patch level (e.g. 
5.1.2600-SP3) + "version": "-".join(platform.win32_ver()[1:3]), + } +elif sys.platform.startswith("java"): + _name, _ver, _arch = platform.java_ver()[-1] + _METADATA["os"] = { + # Linux, Windows 7, Mac OS X, etc. + "type": _name, + "name": _name, + # x86, x86_64, AMD64, etc. + "architecture": _arch, + # Linux kernel version, OSX version, etc. + "version": _ver, + } +else: + # Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11) + _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) + _METADATA["os"] = { + "type": platform.system(), + "name": " ".join([part for part in _aliased[:2] if part]), + "architecture": platform.machine(), + "version": _aliased[2], + } + +if platform.python_implementation().startswith("PyPy"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.pypy_version_info)), # type: ignore + "(Python %s)" % ".".join(map(str, sys.version_info)), + ) + ) +elif sys.platform.startswith("java"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.version_info)), + "(%s)" % " ".join((platform.system(), platform.release())), + ) + ) +else: + _METADATA["platform"] = " ".join( + (platform.python_implementation(), ".".join(map(str, sys.version_info))) + ) + +DOCKER_ENV_PATH = "/.dockerenv" +ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" + +RUNTIME_NAME_DOCKER = "docker" +ORCHESTRATOR_NAME_K8S = "kubernetes" + + +def get_container_env_info() -> dict[str, str]: + """Returns the runtime and orchestrator of a container. 
+ If neither value is present, the metadata client.env.container field will be omitted.""" + container = {} + + if Path(DOCKER_ENV_PATH).exists(): + container["runtime"] = RUNTIME_NAME_DOCKER + if os.getenv(ENV_VAR_K8S): + container["orchestrator"] = ORCHESTRATOR_NAME_K8S + + return container + + +def _is_lambda() -> bool: + if os.getenv("AWS_LAMBDA_RUNTIME_API"): + return True + env = os.getenv("AWS_EXECUTION_ENV") + if env: + return env.startswith("AWS_Lambda_") + return False + + +def _is_azure_func() -> bool: + return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) + + +def _is_gcp_func() -> bool: + return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) + + +def _is_vercel() -> bool: + return bool(os.getenv("VERCEL")) + + +def _is_faas() -> bool: + return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() + + +def _getenv_int(key: str) -> Optional[int]: + """Like os.getenv but returns an int, or None if the value is missing/malformed.""" + val = os.getenv(key) + if not val: + return None + try: + return int(val) + except ValueError: + return None + + +def _metadata_env() -> dict[str, Any]: + env: dict[str, Any] = {} + container = get_container_env_info() + if container: + env["container"] = container + # Skip if multiple (or no) envs are matched. 
+ if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: + return env + if _is_lambda(): + env["name"] = "aws.lambda" + region = os.getenv("AWS_REGION") + if region: + env["region"] = region + memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") + if memory_mb is not None: + env["memory_mb"] = memory_mb + elif _is_azure_func(): + env["name"] = "azure.func" + elif _is_gcp_func(): + env["name"] = "gcp.func" + region = os.getenv("FUNCTION_REGION") + if region: + env["region"] = region + memory_mb = _getenv_int("FUNCTION_MEMORY_MB") + if memory_mb is not None: + env["memory_mb"] = memory_mb + timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") + if timeout_sec is not None: + env["timeout_sec"] = timeout_sec + elif _is_vercel(): + env["name"] = "vercel" + region = os.getenv("VERCEL_REGION") + if region: + env["region"] = region + return env + + +_MAX_METADATA_SIZE = 512 + + +# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations +def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: + """Perform metadata truncation.""" + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 1. Omit fields from env except env.name. + env_name = metadata.get("env", {}).get("name") + if env_name: + metadata["env"] = {"name": env_name} + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 2. Omit fields from os except os.type. + os_type = metadata.get("os", {}).get("type") + if os_type: + metadata["os"] = {"type": os_type} + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 3. Omit the env document entirely. + metadata.pop("env", None) + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 4. Truncate platform. 
+ overflow = encoded_size - _MAX_METADATA_SIZE + plat = metadata.get("platform", "") + if plat: + plat = plat[:-overflow] + if plat: + metadata["platform"] = plat + else: + metadata.pop("platform", None) + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 5. Truncate driver info. + overflow = encoded_size - _MAX_METADATA_SIZE + driver = metadata.get("driver", {}) + if driver: + # Truncate driver version. + driver_version = driver.get("version")[:-overflow] + if len(driver_version) >= len(_METADATA["driver"]["version"]): + metadata["driver"]["version"] = driver_version + else: + metadata["driver"]["version"] = _METADATA["driver"]["version"] + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # Truncate driver name. + overflow = encoded_size - _MAX_METADATA_SIZE + driver_name = driver.get("name")[:-overflow] + if len(driver_name) >= len(_METADATA["driver"]["name"]): + metadata["driver"]["name"] = driver_name + else: + metadata["driver"]["name"] = _METADATA["driver"]["name"] + + +# If the first getaddrinfo call of this interpreter's life is on a thread, +# while the main thread holds the import lock, getaddrinfo deadlocks trying +# to import the IDNA codec. Import it here, where presumably we're on the +# main thread, to avoid the deadlock. See PYTHON-607. +"foo".encode("idna") + + +class PoolOptions: + """Read only connection pool options for an AsyncMongoClient/MongoClient. + + Should not be instantiated directly by application developers. 
Access + a client's pool options via + :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: + + pool_opts = client.options.pool_options + pool_opts.max_pool_size + pool_opts.min_pool_size + + """ + + __slots__ = ( + "__max_pool_size", + "__min_pool_size", + "__max_idle_time_seconds", + "__connect_timeout", + "__socket_timeout", + "__wait_queue_timeout", + "__ssl_context", + "__tls_allow_invalid_hostnames", + "__event_listeners", + "__appname", + "__driver", + "__metadata", + "__compression_settings", + "__max_connecting", + "__pause_enabled", + "__server_api", + "__load_balanced", + "__credentials", + ) + + def __init__( + self, + max_pool_size: int = MAX_POOL_SIZE, + min_pool_size: int = MIN_POOL_SIZE, + max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, + connect_timeout: Optional[float] = None, + socket_timeout: Optional[float] = None, + wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, + ssl_context: Optional[SSLContext] = None, + tls_allow_invalid_hostnames: bool = False, + event_listeners: Optional[_EventListeners] = None, + appname: Optional[str] = None, + driver: Optional[DriverInfo] = None, + compression_settings: Optional[CompressionSettings] = None, + max_connecting: int = MAX_CONNECTING, + pause_enabled: bool = True, + server_api: Optional[ServerApi] = None, + load_balanced: Optional[bool] = None, + credentials: Optional[MongoCredential] = None, + ): + self.__max_pool_size = max_pool_size + self.__min_pool_size = min_pool_size + self.__max_idle_time_seconds = max_idle_time_seconds + self.__connect_timeout = connect_timeout + self.__socket_timeout = socket_timeout + self.__wait_queue_timeout = wait_queue_timeout + self.__ssl_context = ssl_context + self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames + self.__event_listeners = event_listeners + self.__appname = appname + self.__driver = driver + self.__compression_settings = compression_settings + self.__max_connecting = max_connecting + self.__pause_enabled = 
pause_enabled + self.__server_api = server_api + self.__load_balanced = load_balanced + self.__credentials = credentials + self.__metadata = copy.deepcopy(_METADATA) + if appname: + self.__metadata["application"] = {"name": appname} + + # Combine the "driver" AsyncMongoClient option with PyMongo's info, like: + # { + # 'driver': { + # 'name': 'PyMongo|MyDriver', + # 'version': '4.2.0|1.2.3', + # }, + # 'platform': 'CPython 3.8.0|MyPlatform' + # } + if driver: + if driver.name: + self.__metadata["driver"]["name"] = "{}|{}".format( + _METADATA["driver"]["name"], + driver.name, + ) + if driver.version: + self.__metadata["driver"]["version"] = "{}|{}".format( + _METADATA["driver"]["version"], + driver.version, + ) + if driver.platform: + self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) + + env = _metadata_env() + if env: + self.__metadata["env"] = env + + _truncate_metadata(self.__metadata) + + @property + def _credentials(self) -> Optional[MongoCredential]: + """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" + return self.__credentials + + @property + def non_default_options(self) -> dict[str, Any]: + """The non-default options this pool was created with. + + Added for CMAP's :class:`PoolCreatedEvent`. 
+ """ + opts = {} + if self.__max_pool_size != MAX_POOL_SIZE: + opts["maxPoolSize"] = self.__max_pool_size + if self.__min_pool_size != MIN_POOL_SIZE: + opts["minPoolSize"] = self.__min_pool_size + if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: + assert self.__max_idle_time_seconds is not None + opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 + if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: + assert self.__wait_queue_timeout is not None + opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 + if self.__max_connecting != MAX_CONNECTING: + opts["maxConnecting"] = self.__max_connecting + return opts + + @property + def max_pool_size(self) -> float: + """The maximum allowable number of concurrent connections to each + connected server. Requests to a server will block if there are + `maxPoolSize` outstanding connections to the requested server. + Defaults to 100. Cannot be 0. + + When a server's pool has reached `max_pool_size`, operations for that + server block waiting for a socket to be returned to the pool. If + ``waitQueueTimeoutMS`` is set, a blocked operation will raise + :exc:`~pymongo.errors.ConnectionFailure` after a timeout. + By default ``waitQueueTimeoutMS`` is not set. + """ + return self.__max_pool_size + + @property + def min_pool_size(self) -> int: + """The minimum required number of concurrent connections that the pool + will maintain to each connected server. Default is 0. + """ + return self.__min_pool_size + + @property + def max_connecting(self) -> int: + """The maximum number of concurrent connection creation attempts per + pool. Defaults to 2. + """ + return self.__max_connecting + + @property + def pause_enabled(self) -> bool: + return self.__pause_enabled + + @property + def max_idle_time_seconds(self) -> Optional[int]: + """The maximum number of seconds that a connection can remain + idle in the pool before being removed and replaced. Defaults to + `None` (no limit). 
+ """ + return self.__max_idle_time_seconds + + @property + def connect_timeout(self) -> Optional[float]: + """How long a connection can take to be opened before timing out.""" + return self.__connect_timeout + + @property + def socket_timeout(self) -> Optional[float]: + """How long a send or receive on a socket can take before timing out.""" + return self.__socket_timeout + + @property + def wait_queue_timeout(self) -> Optional[int]: + """How long a thread will wait for a socket from the pool if the pool + has no free sockets. + """ + return self.__wait_queue_timeout + + @property + def _ssl_context(self) -> Optional[SSLContext]: + """An SSLContext instance or None.""" + return self.__ssl_context + + @property + def tls_allow_invalid_hostnames(self) -> bool: + """If True skip ssl.match_hostname.""" + return self.__tls_allow_invalid_hostnames + + @property + def _event_listeners(self) -> Optional[_EventListeners]: + """An instance of pymongo.monitoring._EventListeners.""" + return self.__event_listeners + + @property + def appname(self) -> Optional[str]: + """The application name, for sending with hello in server handshake.""" + return self.__appname + + @property + def driver(self) -> Optional[DriverInfo]: + """Driver name and version, for sending with hello in handshake.""" + return self.__driver + + @property + def _compression_settings(self) -> Optional[CompressionSettings]: + return self.__compression_settings + + @property + def metadata(self) -> dict[str, Any]: + """A dict of metadata about the application, driver, os, and platform.""" + return self.__metadata.copy() + + @property + def server_api(self) -> Optional[ServerApi]: + """A pymongo.server_api.ServerApi or None.""" + return self.__server_api + + @property + def load_balanced(self) -> Optional[bool]: + """True if this Pool is configured in load balanced mode.""" + return self.__load_balanced diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index de15cbfca..a7e138cd9 100644 --- 
a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -1,6 +1,6 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2012-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License", # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -12,10 +12,612 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous ReadPreferences API for compatibility.""" +"""Utilities for choosing which member of a replica set to read from.""" + from __future__ import annotations -from pymongo.synchronous.read_preferences import * # noqa: F403 -from pymongo.synchronous.read_preferences import __doc__ as original_doc +from collections import abc +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence -__doc__ = original_doc +from pymongo import max_staleness_selectors +from pymongo.errors import ConfigurationError +from pymongo.server_selectors import ( + member_with_tags_server_selector, + secondary_with_tags_server_selector, +) + +if TYPE_CHECKING: + from pymongo.server_selectors import Selection + from pymongo.topology_description import TopologyDescription + + +_PRIMARY = 0 +_PRIMARY_PREFERRED = 1 +_SECONDARY = 2 +_SECONDARY_PREFERRED = 3 +_NEAREST = 4 + + +_MONGOS_MODES = ( + "primary", + "primaryPreferred", + "secondary", + "secondaryPreferred", + "nearest", +) + +_Hedge = Mapping[str, Any] +_TagSets = Sequence[Mapping[str, Any]] + + +def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: + """Validate tag sets for a MongoClient.""" + if tag_sets is None: + return tag_sets + + if not isinstance(tag_sets, (list, tuple)): + raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") + if len(tag_sets) == 0: + raise ValueError( + f"Tag sets {tag_sets!r} invalid, must be None or contain at least one 
set of tags" + ) + + for tags in tag_sets: + if not isinstance(tags, abc.Mapping): + raise TypeError( + f"Tag set {tags!r} invalid, must be an instance of dict, " + "bson.son.SON or other type that inherits from " + "collection.Mapping" + ) + + return list(tag_sets) + + +def _invalid_max_staleness_msg(max_staleness: Any) -> str: + return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness + + +# Some duplication with common.py to avoid import cycle. +def _validate_max_staleness(max_staleness: Any) -> int: + """Validate max_staleness.""" + if max_staleness == -1: + return -1 + + if not isinstance(max_staleness, int): + raise TypeError(_invalid_max_staleness_msg(max_staleness)) + + if max_staleness <= 0: + raise ValueError(_invalid_max_staleness_msg(max_staleness)) + + return max_staleness + + +def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: + """Validate hedge.""" + if hedge is None: + return None + + if not isinstance(hedge, dict): + raise TypeError(f"hedge must be a dictionary, not {hedge!r}") + + return hedge + + +class _ServerMode: + """Base class for all read preferences.""" + + __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") + + def __init__( + self, + mode: int, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + self.__mongos_mode = _MONGOS_MODES[mode] + self.__mode = mode + self.__tag_sets = _validate_tag_sets(tag_sets) + self.__max_staleness = _validate_max_staleness(max_staleness) + self.__hedge = _validate_hedge(hedge) + + @property + def name(self) -> str: + """The name of this read preference.""" + return self.__class__.__name__ + + @property + def mongos_mode(self) -> str: + """The mongos mode of this read preference.""" + return self.__mongos_mode + + @property + def document(self) -> dict[str, Any]: + """Read preference as a document.""" + doc: dict[str, Any] = {"mode": self.__mongos_mode} + if self.__tag_sets not 
in (None, [{}]): + doc["tags"] = self.__tag_sets + if self.__max_staleness != -1: + doc["maxStalenessSeconds"] = self.__max_staleness + if self.__hedge not in (None, {}): + doc["hedge"] = self.__hedge + return doc + + @property + def mode(self) -> int: + """The mode of this read preference instance.""" + return self.__mode + + @property + def tag_sets(self) -> _TagSets: + """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to + read only from members whose ``dc`` tag has the value ``"ny"``. + To specify a priority-order for tag sets, provide a list of + tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag + set, ``{}``, means "read from any member that matches the mode, + ignoring tags." MongoClient tries each set of tags in turn + until it finds a set of tags with at least one matching member. + For example, to only send a query to an analytic node:: + + Nearest(tag_sets=[{"node":"analytics"}]) + + Or using :class:`SecondaryPreferred`:: + + SecondaryPreferred(tag_sets=[{"node":"analytics"}]) + + .. seealso:: `Data-Center Awareness + `_ + """ + return list(self.__tag_sets) if self.__tag_sets else [{}] + + @property + def max_staleness(self) -> int: + """The maximum estimated length of time (in seconds) a replica set + secondary can fall behind the primary in replication before it will + no longer be selected for operations, or -1 for no maximum. + """ + return self.__max_staleness + + @property + def hedge(self) -> Optional[_Hedge]: + """The read preference ``hedge`` parameter. + + A dictionary that configures how the server will perform hedged reads. + It consists of the following keys: + + - ``enabled``: Enables or disables hedged reads in sharded clusters. + + Hedged reads are automatically enabled in MongoDB 4.4+ when using a + ``nearest`` read preference. 
To explicitly enable hedged reads, set + the ``enabled`` key to ``true``:: + + >>> Nearest(hedge={'enabled': True}) + + To explicitly disable hedged reads, set the ``enabled`` key to + ``False``:: + + >>> Nearest(hedge={'enabled': False}) + + .. versionadded:: 3.11 + """ + return self.__hedge + + @property + def min_wire_version(self) -> int: + """The wire protocol version the server must support. + + Some read preferences impose version requirements on all servers (e.g. + maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5). + + All servers' maxWireVersion must be at least this read preference's + `min_wire_version`, or the driver raises + :exc:`~pymongo.errors.ConfigurationError`. + """ + return 0 if self.__max_staleness == -1 else 5 + + def __repr__(self) -> str: + return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format( + self.name, + self.__tag_sets, + self.__max_staleness, + self.__hedge, + ) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, _ServerMode): + return ( + self.mode == other.mode + and self.tag_sets == other.tag_sets + and self.max_staleness == other.max_staleness + and self.hedge == other.hedge + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __getstate__(self) -> dict[str, Any]: + """Return value of object for pickling. + + Needed explicitly because __slots__() defined. 
+ """ + return { + "mode": self.__mode, + "tag_sets": self.__tag_sets, + "max_staleness": self.__max_staleness, + "hedge": self.__hedge, + } + + def __setstate__(self, value: Mapping[str, Any]) -> None: + """Restore from pickling.""" + self.__mode = value["mode"] + self.__mongos_mode = _MONGOS_MODES[self.__mode] + self.__tag_sets = _validate_tag_sets(value["tag_sets"]) + self.__max_staleness = _validate_max_staleness(value["max_staleness"]) + self.__hedge = _validate_hedge(value["hedge"]) + + def __call__(self, selection: Selection) -> Selection: + return selection + + +class Primary(_ServerMode): + """Primary read preference. + + * When directly connected to one mongod queries are allowed if the server + is standalone or a replica set primary. + * When connected to a mongos queries are sent to the primary of a shard. + * When connected to a replica set queries are sent to the primary of + the replica set. + """ + + __slots__ = () + + def __init__(self) -> None: + super().__init__(_PRIMARY) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to a Selection.""" + return selection.primary_selection + + def __repr__(self) -> str: + return "Primary()" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, _ServerMode): + return other.mode == _PRIMARY + return NotImplemented + + +class PrimaryPreferred(_ServerMode): + """PrimaryPreferred read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are sent to the primary of a shard if + available, otherwise a shard secondary. + * When connected to a replica set queries are sent to the primary if + available, otherwise a secondary. + + .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first + created reads will be routed to an available secondary until the + primary of the replica set is discovered. 
+ + :param tag_sets: The :attr:`~tag_sets` to use if the primary is not + available. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` to use if the primary is not available. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + if selection.primary: + return selection.primary_selection + else: + return secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class Secondary(_ServerMode): + """Secondary read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among shard + secondaries. An error is raised if no secondaries are available. + * When connected to a replica set queries are distributed among + secondaries. An error is raised if no secondaries are available. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. 
versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + return secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class SecondaryPreferred(_ServerMode): + """SecondaryPreferred read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among shard + secondaries, or the shard primary if no secondary is available. + * When connected to a replica set queries are distributed among + secondaries, or the primary if no secondary is available. + + .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first + created reads will be routed to the primary of the replica set until + an available secondary is discovered. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. 
+ """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + secondaries = secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + if secondaries: + return secondaries + else: + return selection.primary_selection + + +class Nearest(_ServerMode): + """Nearest read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among all members of + a shard. + * When connected to a replica set queries are distributed among all + members. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_NEAREST, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + return member_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class _AggWritePref: + """Agg $out/$merge write preference. 
+ + * If there are readable servers and there is any pre-5.0 server, use + primary read preference. + * Otherwise use `pref` read preference. + + :param pref: The read preference to use on MongoDB 5.0+. + """ + + __slots__ = ("pref", "effective_pref") + + def __init__(self, pref: _ServerMode): + self.pref = pref + self.effective_pref: _ServerMode = ReadPreference.PRIMARY + + def selection_hook(self, topology_description: TopologyDescription) -> None: + common_wv = topology_description.common_wire_version + if ( + topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) + and common_wv + and common_wv < 13 + ): + self.effective_pref = ReadPreference.PRIMARY + else: + self.effective_pref = self.pref + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to a Selection.""" + return self.effective_pref(selection) + + def __repr__(self) -> str: + return f"_AggWritePref(pref={self.pref!r})" + + # Proxy other calls to the effective_pref so that _AggWritePref can be + # used in place of an actual read preference. + def __getattr__(self, name: str) -> Any: + return getattr(self.effective_pref, name) + + +_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) + + +def make_read_preference( + mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 +) -> _ServerMode: + if mode == _PRIMARY: + if tag_sets not in (None, [{}]): + raise ConfigurationError("Read preference primary cannot be combined with tags") + if max_staleness != -1: + raise ConfigurationError( + "Read preference primary cannot be combined with maxStalenessSeconds" + ) + return Primary() + return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore + + +_MODES = ( + "PRIMARY", + "PRIMARY_PREFERRED", + "SECONDARY", + "SECONDARY_PREFERRED", + "NEAREST", +) + + +class ReadPreference: + """An enum that defines some commonly used read preference modes. 
+ + Apps can also create a custom read preference, for example:: + + Nearest(tag_sets=[{"node":"analytics"}]) + + See :doc:`/examples/high_availability` for code examples. + + A read preference is used in three cases: + + :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: + + - ``PRIMARY``: Queries are allowed if the server is standalone or a replica + set primary. + - All other modes allow queries to standalone servers, to a replica set + primary, or to replica set secondaries. + + :class:`~pymongo.mongo_client.MongoClient` initialized with the + ``replicaSet`` option: + + - ``PRIMARY``: Read from the primary. This is the default, and provides the + strongest consistency. If no primary is available, raise + :class:`~pymongo.errors.AutoReconnect`. + + - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is + none, read from a secondary. + + - ``SECONDARY``: Read from a secondary. If no secondary is available, + raise :class:`~pymongo.errors.AutoReconnect`. + + - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise + from the primary. + + - ``NEAREST``: Read from any member. + + :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a + sharded cluster of replica sets: + + - ``PRIMARY``: Read from the primary of the shard, or raise + :class:`~pymongo.errors.OperationFailure` if there is none. + This is the default. + + - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is + none, read from a secondary of the shard. + + - ``SECONDARY``: Read from a secondary of the shard, or raise + :class:`~pymongo.errors.OperationFailure` if there is none. + + - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, + otherwise from the shard primary. + + - ``NEAREST``: Read from any shard member. 
+ """ + + PRIMARY = Primary() + PRIMARY_PREFERRED = PrimaryPreferred() + SECONDARY = Secondary() + SECONDARY_PREFERRED = SecondaryPreferred() + NEAREST = Nearest() + + +def read_pref_mode_from_name(name: str) -> int: + """Get the read preference mode from mongos/uri name.""" + return _MONGOS_MODES.index(name) + + +class MovingAverage: + """Tracks an exponentially-weighted moving average.""" + + average: Optional[float] + + def __init__(self) -> None: + self.average = None + + def add_sample(self, sample: float) -> None: + if sample < 0: + # Likely system time change while waiting for hello response + # and not using time.monotonic. Ignore it, the next one will + # probably be valid. + return + if self.average is None: + self.average = sample + else: + # The Server Selection Spec requires an exponentially weighted + # average with alpha = 0.2. + self.average = 0.8 * self.average + 0.2 * sample + + def get(self) -> Optional[float]: + """Get the calculated average, or None if no samples yet.""" + return self.average + + def reset(self) -> None: + self.average = None diff --git a/pymongo/synchronous/response.py b/pymongo/response.py similarity index 91% rename from pymongo/synchronous/response.py rename to pymongo/response.py index 94fd4df50..e47749423 100644 --- a/pymongo/synchronous/response.py +++ b/pymongo/response.py @@ -20,11 +20,8 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union if TYPE_CHECKING: from datetime import timedelta - from pymongo.synchronous.message import _OpMsg, _OpReply - from pymongo.synchronous.pool import Connection - from pymongo.synchronous.typings import _Address, _DocumentOut - -_IS_SYNC = True + from pymongo.message import _OpMsg, _OpReply + from pymongo.typings import _Address, _AgnosticConnection, _DocumentOut class Response: @@ -92,7 +89,7 @@ class PinnedResponse(Response): self, data: Union[_OpMsg, _OpReply], address: _Address, - conn: Connection, + conn: _AgnosticConnection, request_id: int, duration: 
Optional[timedelta], from_command: bool, @@ -103,7 +100,7 @@ class PinnedResponse(Response): :param data: A network response message. :param address: (host, port) of the source server. - :param conn: The Connection used for the initial query. + :param conn: The AsyncConnection/Connection used for the initial query. :param request_id: The request id of this operation. :param duration: The duration of the operation. :param from_command: If the response is the result of a db command. @@ -116,8 +113,8 @@ class PinnedResponse(Response): self._more_to_come = more_to_come @property - def conn(self) -> Connection: - """The Connection used for the initial query. + def conn(self) -> _AgnosticConnection: + """The AsyncConnection/Connection used for the initial query. The server will send batches on this socket, without waiting for getMores from the client, until the result set is exhausted or there diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 4ee6b340d..6393fce0a 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,288 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-import of synchronous ServerDescription API for compatibility.""" +"""Represent one server the driver is connected to.""" from __future__ import annotations -from pymongo.synchronous.server_description import * # noqa: F403 -from pymongo.synchronous.server_description import __doc__ as original_doc +import time +import warnings +from typing import Any, Mapping, Optional -__doc__ = original_doc +from bson import EPOCH_NAIVE +from bson.objectid import ObjectId +from pymongo.hello import Hello +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import ClusterTime, _Address + + +class ServerDescription: + """Immutable representation of one server. + + :param address: A (host, port) pair + :param hello: Optional Hello instance + :param round_trip_time: Optional float + :param error: Optional, the last error attempting to connect to the server + :param round_trip_time: Optional float, the min latency from the most recent samples + """ + + __slots__ = ( + "_address", + "_server_type", + "_all_hosts", + "_tags", + "_replica_set_name", + "_primary", + "_max_bson_size", + "_max_message_size", + "_max_write_batch_size", + "_min_wire_version", + "_max_wire_version", + "_round_trip_time", + "_min_round_trip_time", + "_me", + "_is_writable", + "_is_readable", + "_ls_timeout_minutes", + "_error", + "_set_version", + "_election_id", + "_cluster_time", + "_last_write_date", + "_last_update_time", + "_topology_version", + ) + + def __init__( + self, + address: _Address, + hello: Optional[Hello] = None, + round_trip_time: Optional[float] = None, + error: Optional[Exception] = None, + min_round_trip_time: float = 0.0, + ) -> None: + self._address = address + if not hello: + hello = Hello({}) + + self._server_type = hello.server_type + self._all_hosts = hello.all_hosts + self._tags = hello.tags + self._replica_set_name = hello.replica_set_name + self._primary = hello.primary + self._max_bson_size = hello.max_bson_size + self._max_message_size = 
hello.max_message_size + self._max_write_batch_size = hello.max_write_batch_size + self._min_wire_version = hello.min_wire_version + self._max_wire_version = hello.max_wire_version + self._set_version = hello.set_version + self._election_id = hello.election_id + self._cluster_time = hello.cluster_time + self._is_writable = hello.is_writable + self._is_readable = hello.is_readable + self._ls_timeout_minutes = hello.logical_session_timeout_minutes + self._round_trip_time = round_trip_time + self._min_round_trip_time = min_round_trip_time + self._me = hello.me + self._last_update_time = time.monotonic() + self._error = error + self._topology_version = hello.topology_version + if error: + details = getattr(error, "details", None) + if isinstance(details, dict): + self._topology_version = details.get("topologyVersion") + + self._last_write_date: Optional[float] + if hello.last_write_date: + # Convert from datetime to seconds. + delta = hello.last_write_date - EPOCH_NAIVE + self._last_write_date = delta.total_seconds() + else: + self._last_write_date = None + + @property + def address(self) -> _Address: + """The address (host, port) of this server.""" + return self._address + + @property + def server_type(self) -> int: + """The type of this server.""" + return self._server_type + + @property + def server_type_name(self) -> str: + """The server type as a human readable string. + + .. 
versionadded:: 3.4 + """ + return SERVER_TYPE._fields[self._server_type] + + @property + def all_hosts(self) -> set[tuple[str, int]]: + """List of hosts, passives, and arbiters known to this server.""" + return self._all_hosts + + @property + def tags(self) -> Mapping[str, Any]: + return self._tags + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or None.""" + return self._replica_set_name + + @property + def primary(self) -> Optional[tuple[str, int]]: + """This server's opinion about who the primary is, or None.""" + return self._primary + + @property + def max_bson_size(self) -> int: + return self._max_bson_size + + @property + def max_message_size(self) -> int: + return self._max_message_size + + @property + def max_write_batch_size(self) -> int: + return self._max_write_batch_size + + @property + def min_wire_version(self) -> int: + return self._min_wire_version + + @property + def max_wire_version(self) -> int: + return self._max_wire_version + + @property + def set_version(self) -> Optional[int]: + return self._set_version + + @property + def election_id(self) -> Optional[ObjectId]: + return self._election_id + + @property + def cluster_time(self) -> Optional[ClusterTime]: + return self._cluster_time + + @property + def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: + warnings.warn( + "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", + DeprecationWarning, + stacklevel=2, + ) + return self._set_version, self._election_id + + @property + def me(self) -> Optional[tuple[str, int]]: + return self._me + + @property + def logical_session_timeout_minutes(self) -> Optional[int]: + return self._ls_timeout_minutes + + @property + def last_write_date(self) -> Optional[float]: + return self._last_write_date + + @property + def last_update_time(self) -> float: + return self._last_update_time + + @property + def round_trip_time(self) -> Optional[float]: + """The current average latency or 
None.""" + # This override is for unittesting only! + if self._address in self._host_to_round_trip_time: + return self._host_to_round_trip_time[self._address] + + return self._round_trip_time + + @property + def min_round_trip_time(self) -> float: + """The min latency from the most recent samples.""" + return self._min_round_trip_time + + @property + def error(self) -> Optional[Exception]: + """The last error attempting to connect to the server, or None.""" + return self._error + + @property + def is_writable(self) -> bool: + return self._is_writable + + @property + def is_readable(self) -> bool: + return self._is_readable + + @property + def mongos(self) -> bool: + return self._server_type == SERVER_TYPE.Mongos + + @property + def is_server_type_known(self) -> bool: + return self.server_type != SERVER_TYPE.Unknown + + @property + def retryable_writes_supported(self) -> bool: + """Checks if this server supports retryable writes.""" + return ( + self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) + ) or self._server_type == SERVER_TYPE.LoadBalancer + + @property + def retryable_reads_supported(self) -> bool: + """Checks if this server supports retryable writes.""" + return self._max_wire_version >= 6 + + @property + def topology_version(self) -> Optional[Mapping[str, Any]]: + return self._topology_version + + def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription: + unknown = ServerDescription(self.address, error=error) + unknown._topology_version = self.topology_version + return unknown + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ServerDescription): + return ( + (self._address == other.address) + and (self._server_type == other.server_type) + and (self._min_wire_version == other.min_wire_version) + and (self._max_wire_version == other.max_wire_version) + and (self._me == other.me) + and (self._all_hosts == other.all_hosts) + and (self._tags == other.tags) + and (self._replica_set_name == 
other.replica_set_name) + and (self._set_version == other.set_version) + and (self._election_id == other.election_id) + and (self._primary == other.primary) + and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) + and (self._error == other.error) + ) + + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __repr__(self) -> str: + errmsg = "" + if self.error: + errmsg = f", error={self.error!r}" + return "<{} {} server_type: {}, rtt: {}{}>".format( + self.__class__.__name__, + self.address, + self.server_type_name, + self.round_trip_time, + errmsg, + ) + + # For unittesting only. Use under no circumstances! + _host_to_round_trip_time: dict = {} diff --git a/pymongo/synchronous/server_selectors.py b/pymongo/server_selectors.py similarity index 97% rename from pymongo/synchronous/server_selectors.py rename to pymongo/server_selectors.py index a3b2066ab..c22ad599e 100644 --- a/pymongo/synchronous/server_selectors.py +++ b/pymongo/server_selectors.py @@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cas from pymongo.server_type import SERVER_TYPE if TYPE_CHECKING: - from pymongo.synchronous.server_description import ServerDescription - from pymongo.synchronous.topology_description import TopologyDescription + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription -_IS_SYNC = True T = TypeVar("T") TagSet = Mapping[str, Any] diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/srv_resolver.py similarity index 98% rename from pymongo/synchronous/srv_resolver.py rename to pymongo/srv_resolver.py index e5481305e..6f6cc285f 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -19,14 +19,12 @@ import ipaddress import random from typing import TYPE_CHECKING, Any, Optional, Union +from pymongo.common import CONNECT_TIMEOUT from pymongo.errors import ConfigurationError 
-from pymongo.synchronous.common import CONNECT_TIMEOUT if TYPE_CHECKING: from dns import resolver -_IS_SYNC = True - def _have_dnspython() -> bool: try: diff --git a/pymongo/synchronous/aggregation.py b/pymongo/synchronous/aggregation.py index a4b5a957c..7c7e6252f 100644 --- a/pymongo/synchronous/aggregation.py +++ b/pymongo/synchronous/aggregation.py @@ -18,20 +18,20 @@ from __future__ import annotations from collections.abc import Callable, Mapping, MutableMapping from typing import TYPE_CHECKING, Any, Optional, Union +from pymongo import common +from pymongo.collation import validate_collation_or_none from pymongo.errors import ConfigurationError -from pymongo.synchronous import common -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref +from pymongo.read_preferences import ReadPreference, _AggWritePref if TYPE_CHECKING: + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode from pymongo.synchronous.server import Server - from pymongo.synchronous.typings import _DocumentType, _Pipeline + from pymongo.typings import _DocumentType, _Pipeline _IS_SYNC = True diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index cb1b23d15..9a3477679 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -18,16 +18,12 @@ from __future__ import annotations import functools import hashlib import hmac -import os import socket -import typing from base64 import standard_b64decode, standard_b64encode -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Mapping, MutableMapping, 
Optional, @@ -36,20 +32,22 @@ from typing import ( from urllib.parse import quote from bson.binary import Binary +from pymongo.auth_shared import ( + MongoCredential, + _authenticate_scram_start, + _parse_scram_response, + _xor, +) from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep from pymongo.synchronous.auth_aws import _authenticate_aws from pymongo.synchronous.auth_oidc import ( _authenticate_oidc, _get_authenticator, - _OIDCAzureCallback, - _OIDCGCPCallback, - _OIDCProperties, - _OIDCTestCallback, ) if TYPE_CHECKING: - from pymongo.synchronous.hello import Hello + from pymongo.hello import Hello from pymongo.synchronous.pool import Connection HAVE_KERBEROS = True @@ -68,210 +66,6 @@ except ImportError: _IS_SYNC = True -MECHANISMS = frozenset( - [ - "GSSAPI", - "MONGODB-CR", - "MONGODB-OIDC", - "MONGODB-X509", - "MONGODB-AWS", - "PLAIN", - "SCRAM-SHA-1", - "SCRAM-SHA-256", - "DEFAULT", - ] -) -"""The authentication mechanisms supported by PyMongo.""" - - -class _Cache: - __slots__ = ("data",) - - _hash_val = hash("_Cache") - - def __init__(self) -> None: - self.data = None - - def __eq__(self, other: object) -> bool: - # Two instances must always compare equal. 
- if isinstance(other, _Cache): - return True - return NotImplemented - - def __ne__(self, other: object) -> bool: - if isinstance(other, _Cache): - return False - return NotImplemented - - def __hash__(self) -> int: - return self._hash_val - - -MongoCredential = namedtuple( - "MongoCredential", - ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], -) -"""A hashable namedtuple of values used for authentication.""" - - -GSSAPIProperties = namedtuple( - "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] -) -"""Mechanism properties for GSSAPI authentication.""" - - -_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) -"""Mechanism properties for MONGODB-AWS authentication.""" - - -def _build_credentials_tuple( - mech: str, - source: Optional[str], - user: str, - passwd: str, - extra: Mapping[str, Any], - database: Optional[str], -) -> MongoCredential: - """Build and return a mechanism specific credentials tuple.""" - if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: - raise ConfigurationError(f"{mech} requires a username.") - if mech == "GSSAPI": - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for GSSAPI") - properties = extra.get("authmechanismproperties", {}) - service_name = properties.get("SERVICE_NAME", "mongodb") - canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False)) - service_realm = properties.get("SERVICE_REALM") - props = GSSAPIProperties( - service_name=service_name, - canonicalize_host_name=canonicalize, - service_realm=service_realm, - ) - # Source is always $external. 
- return MongoCredential(mech, "$external", user, passwd, props, None) - elif mech == "MONGODB-X509": - if passwd is not None: - raise ConfigurationError("Passwords are not supported by MONGODB-X509") - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for MONGODB-X509") - # Source is always $external, user can be None. - return MongoCredential(mech, "$external", user, None, None, None) - elif mech == "MONGODB-AWS": - if user is not None and passwd is None: - raise ConfigurationError("username without a password is not supported by MONGODB-AWS") - if source is not None and source != "$external": - raise ConfigurationError( - "authentication source must be $external or None for MONGODB-AWS" - ) - - properties = extra.get("authmechanismproperties", {}) - aws_session_token = properties.get("AWS_SESSION_TOKEN") - aws_props = _AWSProperties(aws_session_token=aws_session_token) - # user can be None for temporary link-local EC2 credentials. 
- return MongoCredential(mech, "$external", user, passwd, aws_props, None) - elif mech == "MONGODB-OIDC": - properties = extra.get("authmechanismproperties", {}) - callback = properties.get("OIDC_CALLBACK") - human_callback = properties.get("OIDC_HUMAN_CALLBACK") - environ = properties.get("ENVIRONMENT") - token_resource = properties.get("TOKEN_RESOURCE", "") - default_allowed = [ - "*.mongodb.net", - "*.mongodb-dev.net", - "*.mongodb-qa.net", - "*.mongodbgov.net", - "localhost", - "127.0.0.1", - "::1", - ] - allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) - msg = ( - "authentication with MONGODB-OIDC requires providing either a callback or a environment" - ) - if passwd is not None: - msg = "password is not supported by MONGODB-OIDC" - raise ConfigurationError(msg) - if callback or human_callback: - if environ is not None: - raise ConfigurationError(msg) - if callback and human_callback: - msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" - raise ConfigurationError(msg) - elif environ is not None: - if environ == "test": - if user is not None: - msg = "test environment for MONGODB-OIDC does not support username" - raise ConfigurationError(msg) - callback = _OIDCTestCallback() - elif environ == "azure": - passwd = None - if not token_resource: - raise ConfigurationError( - "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" - ) - callback = _OIDCAzureCallback(token_resource) - elif environ == "gcp": - passwd = None - if not token_resource: - raise ConfigurationError( - "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" - ) - callback = _OIDCGCPCallback(token_resource) - else: - raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}") - else: - raise ConfigurationError(msg) - - oidc_props = _OIDCProperties( - callback=callback, - human_callback=human_callback, - environment=environ, - allowed_hosts=allowed_hosts, - token_resource=token_resource, - 
username=user, - ) - return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) - - elif mech == "PLAIN": - source_database = source or database or "$external" - return MongoCredential(mech, source_database, user, passwd, None, None) - else: - source_database = source or database or "admin" - if passwd is None: - raise ConfigurationError("A password is required.") - return MongoCredential(mech, source_database, user, passwd, None, _Cache()) - - -def _xor(fir: bytes, sec: bytes) -> bytes: - """XOR two byte strings together.""" - return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) - - -def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]: - """Split a scram response into key, value pairs.""" - return dict( - typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1)) - for item in response.split(b",") - ) - - -def _authenticate_scram_start( - credentials: MongoCredential, mechanism: str -) -> tuple[bytes, bytes, MutableMapping[str, Any]]: - username = credentials.username - user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") - nonce = standard_b64encode(os.urandom(32)) - first_bare = b"n=" + user + b",r=" + nonce - - cmd = { - "saslStart": 1, - "mechanism": mechanism, - "payload": Binary(b"n,," + first_bare), - "autoAuthorize": 1, - "options": {"skipEmptyExchange": True}, - } - return nonce, first_bare, cmd - def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None: """Authenticate using SCRAM.""" diff --git a/pymongo/synchronous/auth_aws.py b/pymongo/synchronous/auth_aws.py index 04ceb95b3..7c0d24f3a 100644 --- a/pymongo/synchronous/auth_aws.py +++ b/pymongo/synchronous/auth_aws.py @@ -23,7 +23,7 @@ from pymongo.errors import ConfigurationError, OperationFailure if TYPE_CHECKING: from bson.typings import _ReadableBuffer - from pymongo.synchronous.auth import MongoCredential + from pymongo.auth_shared import MongoCredential from pymongo.synchronous.pool import 
Connection _IS_SYNC = True diff --git a/pymongo/synchronous/auth_oidc.py b/pymongo/synchronous/auth_oidc.py index f59b4d54a..6381a408a 100644 --- a/pymongo/synchronous/auth_oidc.py +++ b/pymongo/synchronous/auth_oidc.py @@ -15,79 +15,35 @@ """MONGODB-OIDC Authentication helpers.""" from __future__ import annotations -import abc -import os import threading import time from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union -from urllib.parse import quote import bson from bson.binary import Binary -from pymongo._azure_helpers import _get_azure_response from pymongo._csot import remaining -from pymongo._gcp_helpers import _get_gcp_response +from pymongo.auth_oidc_shared import ( + CALLBACK_VERSION, + HUMAN_CALLBACK_TIMEOUT_SECONDS, + MACHINE_CALLBACK_TIMEOUT_SECONDS, + TIME_BETWEEN_CALLS_SECONDS, + OIDCCallback, + OIDCCallbackContext, + OIDCCallbackResult, + OIDCIdPInfo, + _OIDCProperties, +) from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE +from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE if TYPE_CHECKING: - from pymongo.synchronous.auth import MongoCredential + from pymongo.auth_shared import MongoCredential from pymongo.synchronous.pool import Connection _IS_SYNC = True -@dataclass -class OIDCIdPInfo: - issuer: str - clientId: Optional[str] = field(default=None) - requestScopes: Optional[list[str]] = field(default=None) - - -@dataclass -class OIDCCallbackContext: - timeout_seconds: float - username: str - version: int - refresh_token: Optional[str] = field(default=None) - idp_info: Optional[OIDCIdPInfo] = field(default=None) - - -@dataclass -class OIDCCallbackResult: - access_token: str - expires_in_seconds: Optional[float] = field(default=None) - refresh_token: Optional[str] = field(default=None) - - -class OIDCCallback(abc.ABC): - """A base class for defining OIDC callbacks.""" - - @abc.abstractmethod - 
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - """Convert the given BSON value into our own type.""" - - -@dataclass -class _OIDCProperties: - callback: Optional[OIDCCallback] = field(default=None) - human_callback: Optional[OIDCCallback] = field(default=None) - environment: Optional[str] = field(default=None) - allowed_hosts: list[str] = field(default_factory=list) - token_resource: Optional[str] = field(default=None) - username: str = "" - - -"""Mechanism properties for MONGODB-OIDC authentication.""" - -TOKEN_BUFFER_MINUTES = 5 -HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60 -CALLBACK_VERSION = 1 -MACHINE_CALLBACK_TIMEOUT_SECONDS = 60 -TIME_BETWEEN_CALLS_SECONDS = 0.1 - - def _get_authenticator( credentials: MongoCredential, address: tuple[str, int] ) -> _OIDCAuthenticator: @@ -117,48 +73,6 @@ def _get_authenticator( return credentials.cache.data -class _OIDCTestCallback(OIDCCallback): - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - token_file = os.environ.get("OIDC_TOKEN_FILE") - if not token_file: - raise RuntimeError( - 'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set' - ) - with open(token_file) as fid: - return OIDCCallbackResult(access_token=fid.read().strip()) - - -class _OIDCAWSCallback(OIDCCallback): - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") - if not token_file: - raise RuntimeError( - 'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set' - ) - with open(token_file) as fid: - return OIDCCallbackResult(access_token=fid.read().strip()) - - -class _OIDCAzureCallback(OIDCCallback): - def __init__(self, token_resource: str) -> None: - self.token_resource = quote(token_resource) - - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds) - return OIDCCallbackResult( - 
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"] - ) - - -class _OIDCGCPCallback(OIDCCallback): - def __init__(self, token_resource: str) -> None: - self.token_resource = quote(token_resource) - - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - resp = _get_gcp_response(self.token_resource, context.timeout_seconds) - return OIDCCallbackResult(access_token=resp["access_token"]) - - @dataclass class _OIDCAuthenticator: username: str diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 781acdb4d..4da64c4a7 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -26,7 +28,6 @@ from typing import ( Any, Iterator, Mapping, - NoReturn, Optional, Type, Union, @@ -34,140 +35,50 @@ from typing import ( from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument -from pymongo import _csot -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure, +from pymongo import _csot, common +from pymongo.bulk_shared import ( + _COMMANDS, + _DELETE_ALL, + _merge_command, + _raise_bulk_write_error, + _Run, ) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES -from pymongo.synchronous import common -from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.common import ( +from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, validate_ok_for_update, ) -from pymongo.synchronous.helpers import _get_wce_doc -from pymongo.synchronous.message import ( +from pymongo.errors import ( + ConfigurationError, + InvalidOperation, + NotPrimaryError, + OperationFailure, +) +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import 
_COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import ( _DELETE, _INSERT, _UPDATE, _BulkWriteContext, + _convert_exception, + _convert_write_result, _EncryptedBulkWriteContext, _randint, ) -from pymongo.synchronous.read_preferences import ReadPreference +from pymongo.read_preferences import ReadPreference +from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern +from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from pymongo.synchronous.collection import Collection + from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection - from pymongo.synchronous.typings import _DocumentOut, _DocumentType, _Pipeline + from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _IS_SYNC = True -_DELETE_ALL: int = 0 -_DELETE_ONE: int = 1 - -# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err -_BAD_VALUE: int = 2 -_UNKNOWN_ERROR: int = 8 -_WRITE_CONCERN_ERROR: int = 64 - -_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") - - -class _Run: - """Represents a batch of write operations.""" - - def __init__(self, op_type: int) -> None: - """Initialize a new Run object.""" - self.op_type: int = op_type - self.index_map: list[int] = [] - self.ops: list[Any] = [] - self.idx_offset: int = 0 - - def index(self, idx: int) -> int: - """Get the original index of an operation in this run. - - :param idx: The Run index that maps to the original index. - """ - return self.index_map[idx] - - def add(self, original_index: int, operation: Any) -> None: - """Add an operation to this Run instance. - - :param original_index: The original index of this operation - within a larger bulk operation. - :param operation: The operation document. 
- """ - self.index_map.append(original_index) - self.ops.append(operation) - - -def _merge_command( - run: _Run, - full_result: MutableMapping[str, Any], - offset: int, - result: Mapping[str, Any], -) -> None: - """Merge a write command result into the full bulk result.""" - affected = result.get("n", 0) - - if run.op_type == _INSERT: - full_result["nInserted"] += affected - - elif run.op_type == _DELETE: - full_result["nRemoved"] += affected - - elif run.op_type == _UPDATE: - upserted = result.get("upserted") - if upserted: - n_upserted = len(upserted) - for doc in upserted: - doc["index"] = run.index(doc["index"] + offset) - full_result["upserted"].extend(upserted) - full_result["nUpserted"] += n_upserted - full_result["nMatched"] += affected - n_upserted - else: - full_result["nMatched"] += affected - full_result["nModified"] += result["nModified"] - - write_errors = result.get("writeErrors") - if write_errors: - for doc in write_errors: - # Leave the server response intact for APM. - replacement = doc.copy() - idx = doc["index"] + offset - replacement["index"] = run.index(idx) - # Add the failed operation to the error document. - replacement["op"] = run.ops[idx] - full_result["writeErrors"].append(replacement) - - wce = _get_wce_doc(result) - if wce: - full_result["writeConcernErrors"].append(wce) - - -def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: - """Raise a BulkWriteError from the full bulk api result.""" - # retryWrites on MMAPv1 should raise an actionable error. - if full_result["writeErrors"]: - full_result["writeErrors"].sort(key=lambda error: error["index"]) - err = full_result["writeErrors"][0] - code = err["code"] - msg = err["errmsg"] - if code == 20 and msg.startswith("Transaction numbers"): - errmsg = ( - "This MongoDB deployment does not support " - "retryable writes. Please add retryWrites=false " - "to your connection string." 
- ) - raise OperationFailure(errmsg, code, full_result) - raise BulkWriteError(full_result) - class _Bulk: """The private guts of the bulk write API.""" @@ -204,13 +115,16 @@ class _Bulk: # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None + self.is_encrypted = False @property def bulk_ctx_class(self) -> Type[_BulkWriteContext]: encrypter = self.collection.database.client._encrypter if encrypter and not encrypter._bypass_auto_encryption: + self.is_encrypted = True return _EncryptedBulkWriteContext else: + self.is_encrypted = False return _BulkWriteContext def add_insert(self, document: _DocumentOut) -> None: @@ -315,6 +229,230 @@ class _Bulk: if run.ops: yield run + @_handle_reauth + def write_command( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + docs: list[Mapping[str, Any]], + client: MongoClient, + ) -> dict[str, Any]: + """A proxy for SocketInfo.write_command that handles event publishing.""" + cmd[bwc.field] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, docs) + try: + reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + 
databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return reply # type: ignore[return-value] + + def unack_write( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + max_doc_size: int, + docs: list[Mapping[str, Any]], + client: MongoClient, + ) -> Optional[Mapping[str, Any]]: + """A proxy for Connection.unack_write that handles event publishing.""" + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + 
serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, docs) + try: + result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + 
isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return result # type: ignore[return-value] + + def _execute_batch_unack( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: MongoClient, + ) -> list[Mapping[str, Any]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + write_concern=WriteConcern(w=0), + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. 
+ self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type] + + return to_send + + def _execute_batch( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: MongoClient, + ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + result = bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + codec_options=bwc.codec, + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type] + client._process_response(result, bwc.session) # type: ignore[arg-type] + + return result, to_send # type: ignore[return-value] + def _execute_command( self, generator: Iterator[Any], @@ -387,7 +525,7 @@ class _Bulk: # Run as many ops as possible in one command. if write_concern.acknowledged: - result, to_send = bwc.execute(cmd, ops, client) + result, to_send = self._execute_batch(bwc, cmd, ops, client) # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) @@ -407,7 +545,7 @@ class _Bulk: if self.ordered and "writeErrors" in result: break else: - to_send = bwc.execute_unack(cmd, ops, client) + to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) @@ -458,7 +596,7 @@ class _Bulk: retryable_bulk, session, operation, - bulk=self, + bulk=self, # type: ignore[arg-type] operation_id=op_id, ) @@ -499,7 +637,7 @@ class _Bulk: conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. 
- to_send = bwc.execute_unack(cmd, ops, client) + to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index 1b22ed9be..f7489249d 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -21,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union from bson import CodecOptions, _bson_to_dict from bson.raw_bson import RawBSONDocument from bson.timestamp import Timestamp -from pymongo import _csot +from pymongo import _csot, common +from pymongo.collation import validate_collation_or_none from pymongo.errors import ( ConnectionFailure, CursorNotFound, @@ -29,16 +30,14 @@ from pymongo.errors import ( OperationFailure, PyMongoError, ) -from pymongo.synchronous import common +from pymongo.operations import _Op from pymongo.synchronous.aggregation import ( _AggregationCommand, _CollectionAggregationCommand, _DatabaseAggregationCommand, ) -from pymongo.synchronous.collation import validate_collation_or_none from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline _IS_SYNC = True diff --git a/pymongo/synchronous/client_options.py b/pymongo/synchronous/client_options.py deleted file mode 100644 index 58042220f..000000000 --- a/pymongo/synchronous/client_options.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. 
You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Tools to parse mongo client options.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast - -from bson.codec_options import _parse_codec_options -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.ssl_support import get_ssl_context -from pymongo.synchronous import common -from pymongo.synchronous.compression_support import CompressionSettings -from pymongo.synchronous.monitoring import _EventListener, _EventListeners -from pymongo.synchronous.pool import PoolOptions -from pymongo.synchronous.read_preferences import ( - _ServerMode, - make_read_preference, - read_pref_mode_from_name, -) -from pymongo.synchronous.server_selectors import any_server_selector -from pymongo.write_concern import WriteConcern, validate_boolean - -if TYPE_CHECKING: - from bson.codec_options import CodecOptions - from pymongo.pyopenssl_context import SSLContext - from pymongo.synchronous.auth import MongoCredential - from pymongo.synchronous.encryption_options import AutoEncryptionOpts - from pymongo.synchronous.topology_description import _ServerSelector - -_IS_SYNC = True - - -def _parse_credentials( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> Optional[MongoCredential]: - """Parse authentication credentials.""" - mechanism = options.get("authmechanism", "DEFAULT" if username else None) - source = options.get("authsource") - if username or mechanism: - from pymongo.synchronous.auth import _build_credentials_tuple - - 
return _build_credentials_tuple(mechanism, source, username, password, options, database) - return None - - -def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: - """Parse read preference options.""" - if "read_preference" in options: - return options["read_preference"] - - name = options.get("readpreference", "primary") - mode = read_pref_mode_from_name(name) - tags = options.get("readpreferencetags") - max_staleness = options.get("maxstalenessseconds", -1) - return make_read_preference(mode, tags, max_staleness) - - -def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: - """Parse write concern options.""" - concern = options.get("w") - wtimeout = options.get("wtimeoutms") - j = options.get("journal") - fsync = options.get("fsync") - return WriteConcern(concern, wtimeout, j, fsync) - - -def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: - """Parse read concern options.""" - concern = options.get("readconcernlevel") - return ReadConcern(concern) - - -def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: - """Parse ssl options.""" - use_tls = options.get("tls") - if use_tls is not None: - validate_boolean("tls", use_tls) - - certfile = options.get("tlscertificatekeyfile") - passphrase = options.get("tlscertificatekeyfilepassword") - ca_certs = options.get("tlscafile") - crlfile = options.get("tlscrlfile") - allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) - allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) - disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) - - enabled_tls_opts = [] - for opt in ( - "tlscertificatekeyfile", - "tlscertificatekeyfilepassword", - "tlscafile", - "tlscrlfile", - ): - # Any non-null value of these options implies tls=True. 
- if opt in options and options[opt]: - enabled_tls_opts.append(opt) - for opt in ( - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", - ): - # A value of False for these options implies tls=True. - if opt in options and not options[opt]: - enabled_tls_opts.append(opt) - - if enabled_tls_opts: - if use_tls is None: - # Implicitly enable TLS when one of the tls* options is set. - use_tls = True - elif not use_tls: - # Error since tls is explicitly disabled but a tls option is set. - raise ConfigurationError( - "TLS has not been enabled but the " - "following tls parameters have been set: " - "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) - ) - - if use_tls: - ctx = get_ssl_context( - certfile, - passphrase, - ca_certs, - crlfile, - allow_invalid_certificates, - allow_invalid_hostnames, - disable_ocsp_endpoint_check, - ) - return ctx, allow_invalid_hostnames - return None, allow_invalid_hostnames - - -def _parse_pool_options( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> PoolOptions: - """Parse connection pool options.""" - credentials = _parse_credentials(username, password, database, options) - max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) - min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) - max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) - if max_pool_size is not None and min_pool_size > max_pool_size: - raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") - connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) - socket_timeout = options.get("sockettimeoutms") - wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) - event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) - appname = options.get("appname") - driver = options.get("driver") - server_api = options.get("server_api") - 
compression_settings = CompressionSettings( - options.get("compressors", []), options.get("zlibcompressionlevel", -1) - ) - ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) - load_balanced = options.get("loadbalanced") - max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) - return PoolOptions( - max_pool_size, - min_pool_size, - max_idle_time_seconds, - connect_timeout, - socket_timeout, - wait_queue_timeout, - ssl_context, - tls_allow_invalid_hostnames, - _EventListeners(event_listeners), - appname, - driver, - compression_settings, - max_connecting=max_connecting, - server_api=server_api, - load_balanced=load_balanced, - credentials=credentials, - ) - - -class ClientOptions: - """Read only configuration options for a MongoClient. - - Should not be instantiated directly by application developers. Access - a client's options via :attr:`pymongo.mongo_client.MongoClient.options` - instead. - """ - - def __init__( - self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] - ): - self.__options = options - self.__codec_options = _parse_codec_options(options) - self.__direct_connection = options.get("directconnection") - self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) - # self.__server_selection_timeout is in seconds. Must use full name for - # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
- self.__server_selection_timeout = options.get( - "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT - ) - self.__pool_options = _parse_pool_options(username, password, database, options) - self.__read_preference = _parse_read_preference(options) - self.__replica_set_name = options.get("replicaset") - self.__write_concern = _parse_write_concern(options) - self.__read_concern = _parse_read_concern(options) - self.__connect = options.get("connect") - self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) - self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) - self.__retry_reads = options.get("retryreads", common.RETRY_READS) - self.__server_selector = options.get("server_selector", any_server_selector) - self.__auto_encryption_opts = options.get("auto_encryption_opts") - self.__load_balanced = options.get("loadbalanced") - self.__timeout = options.get("timeoutms") - self.__server_monitoring_mode = options.get( - "servermonitoringmode", common.SERVER_MONITORING_MODE - ) - - @property - def _options(self) -> Mapping[str, Any]: - """The original options used to create this ClientOptions.""" - return self.__options - - @property - def connect(self) -> Optional[bool]: - """Whether to begin discovering a MongoDB topology automatically.""" - return self.__connect - - @property - def codec_options(self) -> CodecOptions: - """A :class:`~bson.codec_options.CodecOptions` instance.""" - return self.__codec_options - - @property - def direct_connection(self) -> Optional[bool]: - """Whether to connect to the deployment in 'Single' topology.""" - return self.__direct_connection - - @property - def local_threshold_ms(self) -> int: - """The local threshold for this instance.""" - return self.__local_threshold_ms - - @property - def server_selection_timeout(self) -> int: - """The server selection timeout for this instance in seconds.""" - return self.__server_selection_timeout - - @property - def server_selector(self) 
-> _ServerSelector: - return self.__server_selector - - @property - def heartbeat_frequency(self) -> int: - """The monitoring frequency in seconds.""" - return self.__heartbeat_frequency - - @property - def pool_options(self) -> PoolOptions: - """A :class:`~pymongo.pool.PoolOptions` instance.""" - return self.__pool_options - - @property - def read_preference(self) -> _ServerMode: - """A read preference instance.""" - return self.__read_preference - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self.__replica_set_name - - @property - def write_concern(self) -> WriteConcern: - """A :class:`~pymongo.write_concern.WriteConcern` instance.""" - return self.__write_concern - - @property - def read_concern(self) -> ReadConcern: - """A :class:`~pymongo.read_concern.ReadConcern` instance.""" - return self.__read_concern - - @property - def timeout(self) -> Optional[float]: - """The configured timeoutMS converted to seconds, or None. - - .. versionadded:: 4.2 - """ - return self.__timeout - - @property - def retry_writes(self) -> bool: - """If this instance should retry supported write operations.""" - return self.__retry_writes - - @property - def retry_reads(self) -> bool: - """If this instance should retry supported read operations.""" - return self.__retry_reads - - @property - def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: - """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" - return self.__auto_encryption_opts - - @property - def load_balanced(self) -> Optional[bool]: - """True if the client was configured to connect to a load balancer.""" - return self.__load_balanced - - @property - def event_listeners(self) -> list[_EventListeners]: - """The event listeners registered for this client. - - See :mod:`~pymongo.monitoring` for details. - - .. 
versionadded:: 4.0 - """ - assert self.__pool_options._event_listeners is not None - return self.__pool_options._event_listeners.event_listeners() - - @property - def server_monitoring_mode(self) -> str: - """The configured serverMonitoringMode option. - - .. versionadded:: 4.5 - """ - return self.__server_monitoring_mode diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 6489dcd27..e07298b49 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -164,12 +164,12 @@ from pymongo.errors import ( PyMongoError, WTimeoutError, ) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.operations import _Op from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_type import SERVER_TYPE from pymongo.synchronous.cursor import _ConnectionManager -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -177,7 +177,7 @@ if TYPE_CHECKING: from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server - from pymongo.synchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = True diff --git a/pymongo/synchronous/collation.py b/pymongo/synchronous/collation.py deleted file mode 100644 index 1ce1ee00b..000000000 --- a/pymongo/synchronous/collation.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for working with `collations`_. - -.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ -""" -from __future__ import annotations - -from typing import Any, Mapping, Optional, Union - -from pymongo.synchronous import common -from pymongo.write_concern import validate_boolean - -_IS_SYNC = True - - -class CollationStrength: - """ - An enum that defines values for `strength` on a - :class:`~pymongo.collation.Collation`. - """ - - PRIMARY = 1 - """Differentiate base (unadorned) characters.""" - - SECONDARY = 2 - """Differentiate character accents.""" - - TERTIARY = 3 - """Differentiate character case.""" - - QUATERNARY = 4 - """Differentiate words with and without punctuation.""" - - IDENTICAL = 5 - """Differentiate unicode code point (characters are exactly identical).""" - - -class CollationAlternate: - """ - An enum that defines values for `alternate` on a - :class:`~pymongo.collation.Collation`. - """ - - NON_IGNORABLE = "non-ignorable" - """Spaces and punctuation are treated as base characters.""" - - SHIFTED = "shifted" - """Spaces and punctuation are *not* considered base characters. - - Spaces and punctuation are distinguished regardless when the - :class:`~pymongo.collation.Collation` strength is at least - :data:`~pymongo.collation.CollationStrength.QUATERNARY`. - - """ - - -class CollationMaxVariable: - """ - An enum that defines values for `max_variable` on a - :class:`~pymongo.collation.Collation`. 
- """ - - PUNCT = "punct" - """Both punctuation and spaces are ignored.""" - - SPACE = "space" - """Spaces alone are ignored.""" - - -class CollationCaseFirst: - """ - An enum that defines values for `case_first` on a - :class:`~pymongo.collation.Collation`. - """ - - UPPER = "upper" - """Sort uppercase characters first.""" - - LOWER = "lower" - """Sort lowercase characters first.""" - - OFF = "off" - """Default for locale or collation strength.""" - - -class Collation: - """Collation - - :param locale: (string) The locale of the collation. This should be a string - that identifies an `ICU locale ID` exactly. For example, ``en_US`` is - valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB - documentation for a list of supported locales. - :param caseLevel: (optional) If ``True``, turn on case sensitivity if - `strength` is 1 or 2 (case sensitivity is implied if `strength` is - greater than 2). Defaults to ``False``. - :param caseFirst: (optional) Specify that either uppercase or lowercase - characters take precedence. Must be one of the following values: - - * :data:`~CollationCaseFirst.UPPER` - * :data:`~CollationCaseFirst.LOWER` - * :data:`~CollationCaseFirst.OFF` (the default) - - :param strength: Specify the comparison strength. This is also - known as the ICU comparison level. This must be one of the following - values: - - * :data:`~CollationStrength.PRIMARY` - * :data:`~CollationStrength.SECONDARY` - * :data:`~CollationStrength.TERTIARY` (the default) - * :data:`~CollationStrength.QUATERNARY` - * :data:`~CollationStrength.IDENTICAL` - - Each successive level builds upon the previous. For example, a - `strength` of :data:`~CollationStrength.SECONDARY` differentiates - characters based both on the unadorned base character and its accents. - - :param numericOrdering: If ``True``, order numbers numerically - instead of in collation order (defaults to ``False``). 
- :param alternate: Specify whether spaces and punctuation are - considered base characters. This must be one of the following values: - - * :data:`~CollationAlternate.NON_IGNORABLE` (the default) - * :data:`~CollationAlternate.SHIFTED` - - :param maxVariable: When `alternate` is - :data:`~CollationAlternate.SHIFTED`, this option specifies what - characters may be ignored. This must be one of the following values: - - * :data:`~CollationMaxVariable.PUNCT` (the default) - * :data:`~CollationMaxVariable.SPACE` - - :param normalization: If ``True``, normalizes text into Unicode - NFD. Defaults to ``False``. - :param backwards: If ``True``, accents on characters are - considered from the back of the word to the front, as it is done in some - French dictionary ordering traditions. Defaults to ``False``. - :param kwargs: Keyword arguments supplying any additional options - to be sent with this Collation object. - - .. versionadded: 3.4 - - """ - - __slots__ = ("__document",) - - def __init__( - self, - locale: str, - caseLevel: Optional[bool] = None, - caseFirst: Optional[str] = None, - strength: Optional[int] = None, - numericOrdering: Optional[bool] = None, - alternate: Optional[str] = None, - maxVariable: Optional[str] = None, - normalization: Optional[bool] = None, - backwards: Optional[bool] = None, - **kwargs: Any, - ) -> None: - locale = common.validate_string("locale", locale) - self.__document: dict[str, Any] = {"locale": locale} - if caseLevel is not None: - self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) - if caseFirst is not None: - self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) - if strength is not None: - self.__document["strength"] = common.validate_integer("strength", strength) - if numericOrdering is not None: - self.__document["numericOrdering"] = validate_boolean( - "numericOrdering", numericOrdering - ) - if alternate is not None: - self.__document["alternate"] = 
common.validate_string("alternate", alternate) - if maxVariable is not None: - self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) - if normalization is not None: - self.__document["normalization"] = validate_boolean("normalization", normalization) - if backwards is not None: - self.__document["backwards"] = validate_boolean("backwards", backwards) - self.__document.update(kwargs) - - @property - def document(self) -> dict[str, Any]: - """The document representation of this collation. - - .. note:: - :class:`Collation` is immutable. Mutating the value of - :attr:`document` does not mutate this :class:`Collation`. - """ - return self.__document.copy() - - def __repr__(self) -> str: - document = self.document - return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, Collation): - return self.document == other.document - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -def validate_collation_or_none( - value: Optional[Union[Mapping[str, Any], Collation]] -) -> Optional[dict[str, Any]]: - if value is None: - return None - if isinstance(value, Collation): - return value.document - if isinstance(value, dict): - return value - raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index c04afbe7c..0b299bf5b 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -41,41 +41,18 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from bson.son import SON from bson.timestamp import Timestamp -from pymongo import ASCENDING, _csot +from pymongo import ASCENDING, _csot, common, helpers_shared, message +from pymongo.collation import validate_collation_or_none +from pymongo.common import _ecoc_coll_name, _esc_coll_name from 
pymongo.errors import ( ConfigurationError, InvalidName, InvalidOperation, OperationFailure, ) -from pymongo.read_concern import DEFAULT_READ_CONCERN -from pymongo.results import ( - BulkWriteResult, - DeleteResult, - InsertManyResult, - InsertOneResult, - UpdateResult, -) -from pymongo.synchronous import common, helpers, message -from pymongo.synchronous.aggregation import ( - _CollectionAggregationCommand, - _CollectionRawAggregationCommand, -) -from pymongo.synchronous.bulk import _Bulk -from pymongo.synchronous.change_stream import CollectionChangeStream -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.command_cursor import ( - CommandCursor, - RawBatchCommandCursor, -) -from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name -from pymongo.synchronous.cursor import ( - Cursor, - RawBatchCursor, -) -from pymongo.synchronous.helpers import _check_write_command_response -from pymongo.synchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS -from pymongo.synchronous.operations import ( +from pymongo.helpers_shared import _check_write_command_response +from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS +from pymongo.operations import ( DeleteMany, DeleteOne, IndexModel, @@ -88,8 +65,30 @@ from pymongo.synchronous.operations import ( _IndexList, _Op, ) -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.results import ( + BulkWriteResult, + DeleteResult, + InsertManyResult, + InsertOneResult, + UpdateResult, +) +from pymongo.synchronous.aggregation import ( + _CollectionAggregationCommand, + _CollectionRawAggregationCommand, +) +from pymongo.synchronous.bulk import _Bulk +from pymongo.synchronous.change_stream import CollectionChangeStream 
+from pymongo.synchronous.command_cursor import ( + CommandCursor, + RawBatchCommandCursor, +) +from pymongo.synchronous.cursor import ( + Cursor, + RawBatchCursor, +) +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean _IS_SYNC = True @@ -125,10 +124,10 @@ class ReturnDocument: if TYPE_CHECKING: import bson + from pymongo.collation import Collation from pymongo.read_concern import ReadConcern from pymongo.synchronous.aggregation import _AggregationCommand from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.collation import Collation from pymongo.synchronous.database import Database from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server @@ -995,7 +994,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) update_doc["hint"] = hint command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} if let is not None: @@ -1476,7 +1475,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." 
) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) delete_doc["hint"] = hint command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]} @@ -2092,7 +2091,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}}) cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}} if "hint" in kwargs and not isinstance(kwargs["hint"], str): - kwargs["hint"] = helpers._index_document(kwargs["hint"]) + kwargs["hint"] = helpers_shared._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd.update(kwargs) @@ -2425,7 +2424,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): ) -> None: name = index_or_name if isinstance(index_or_name, list): - name = helpers._gen_index_name(index_or_name) + name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): raise TypeError("index_or_name must be an instance of str or list") @@ -3154,15 +3153,15 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd["let"] = let cmd.update(kwargs) if projection is not None: - cmd["fields"] = helpers._fields_list_to_dict(projection, "projection") + cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection") if sort is not None: - cmd["sort"] = helpers._index_document(sort) + cmd["sort"] = helpers_shared._index_document(sort) if upsert is not None: validate_boolean("upsert", upsert) cmd["upsert"] = upsert if hint is not None: if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) write_concern = self._write_concern_for_cmd(cmd, session) diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index acd658d69..ae245a7fc 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -31,16 +31,16 @@ from typing 
import ( from bson import CodecOptions, _convert_raw_document_lists_to_streams from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.synchronous.cursor import _ConnectionManager -from pymongo.synchronous.message import ( +from pymongo.message import ( _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore, ) -from pymongo.synchronous.response import PinnedResponse -from pymongo.synchronous.typings import _Address, _DocumentOut, _DocumentType +from pymongo.response import PinnedResponse +from pymongo.synchronous.cursor import _ConnectionManager +from pymongo.typings import _Address, _DocumentOut, _DocumentType if TYPE_CHECKING: from pymongo.synchronous.client_session import ClientSession @@ -272,7 +272,7 @@ class CommandCursor(Generic[_DocumentType]): if isinstance(response, PinnedResponse): if not self._sock_mgr: - self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) + self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] if response.from_command: cursor = response.docs[0]["cursor"] documents = cursor["nextBatch"] diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 121cee810..28af19724 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -36,17 +36,16 @@ from typing import ( from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON -from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort -from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _create_lock -from pymongo.synchronous import helpers -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.common import ( +from pymongo import helpers_shared +from pymongo.collation import 
validate_collation_or_none +from pymongo.common import ( validate_is_document_type, validate_is_mapping, ) -from pymongo.synchronous.helpers import next -from pymongo.synchronous.message import ( +from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.lock import _create_lock +from pymongo.message import ( _CursorAddress, _GetMore, _OpMsg, @@ -55,18 +54,19 @@ from pymongo.synchronous.message import ( _RawBatchGetMore, _RawBatchQuery, ) -from pymongo.synchronous.response import PinnedResponse -from pymongo.synchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType +from pymongo.response import PinnedResponse +from pymongo.synchronous.helpers import next +from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import validate_boolean if TYPE_CHECKING: from _typeshed import SupportsItems from bson.codec_options import CodecOptions + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.collection import Collection from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode _IS_SYNC = True @@ -179,7 +179,7 @@ class Cursor(Generic[_DocumentType]): allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use) if projection is not None: - projection = helpers._fields_list_to_dict(projection, "projection") + projection = helpers_shared._fields_list_to_dict(projection, "projection") if let is not None: validate_is_document_type("let", let) @@ -191,7 +191,7 @@ class Cursor(Generic[_DocumentType]): self._skip = skip self._limit = limit self._batch_size = batch_size - self._ordering = sort and helpers._index_document(sort) or None + self._ordering = sort and helpers_shared._index_document(sort) or None self._max_scan = max_scan 
self._explain = False self._comment = comment @@ -740,8 +740,8 @@ class Cursor(Generic[_DocumentType]): key, if not given :data:`~pymongo.ASCENDING` is assumed """ self._check_okay_to_chain() - keys = helpers._index_list(key_or_list, direction) - self._ordering = helpers._index_document(keys) + keys = helpers_shared._index_list(key_or_list, direction) + self._ordering = helpers_shared._index_document(keys) return self def explain(self) -> _DocumentType: @@ -772,7 +772,7 @@ class Cursor(Generic[_DocumentType]): if isinstance(index, str): self._hint = index else: - self._hint = helpers._index_document(index) + self._hint = helpers_shared._index_document(index) def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]: """Adds a 'hint', telling Mongo the proper index to use for the query. @@ -1120,7 +1120,7 @@ class Cursor(Generic[_DocumentType]): self._address = response.address if isinstance(response, PinnedResponse): if not self._sock_mgr: - self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) + self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] cmd_name = operation.name docs = response.docs diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 92521d7c1..eaef0558d 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -33,18 +33,17 @@ from typing import ( from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.timestamp import Timestamp -from pymongo import _csot +from pymongo import _csot, common +from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation -from pymongo.synchronous import common +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.synchronous.aggregation import 
_DatabaseAggregationCommand from pymongo.synchronous.change_stream import DatabaseChangeStream from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: import bson @@ -151,7 +150,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): >>> db1.read_preference Primary() - >>> from pymongo.synchronous.read_preferences import Secondary + >>> from pymongo.read_preferences import Secondary >>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}])) >>> db1.read_preference Primary() diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index e498fc9e1..909fd12c8 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -58,7 +58,9 @@ from bson.codec_options import CodecOptions from bson.errors import BSONError from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from pymongo import _csot +from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon +from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.errors import ( ConfigurationError, EncryptedCollectionError, @@ -68,19 +70,18 @@ from pymongo.errors import ( ServerSelectionTimeoutError, ) from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall +from pymongo.operations import UpdateOne +from pymongo.pool_options import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context from 
pymongo.synchronous.collection import Collection -from pymongo.synchronous.common import CONNECT_TIMEOUT from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database -from pymongo.synchronous.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import UpdateOne -from pymongo.synchronous.pool import PoolOptions, _configured_socket, _raise_connection_failure -from pymongo.synchronous.typings import _DocumentType, _DocumentTypeArg -from pymongo.synchronous.uri_parser import parse_host +from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure +from pymongo.typings import _DocumentType, _DocumentTypeArg +from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -379,7 +380,10 @@ class _Encrypter: ) io_callbacks = _EncryptionIO( # type:ignore[misc] - metadata_client, key_vault_coll, mongocryptd_client, opts + metadata_client, + key_vault_coll, # type:ignore[arg-type] + mongocryptd_client, + opts, ) self._auto_encrypter = AutoEncrypter( io_callbacks, @@ -457,7 +461,14 @@ class Algorithm(str, enum.Enum): RANGE = "Range" """Range. - .. versionadded:: 4.8 + .. versionadded:: 4.9 + """ + RANGEPREVIEW = "RangePreview" + """**DEPRECATED** - RangePreview. + + .. note:: Support for RangePreview is deprecated. Use :attr:`Algorithm.RANGE` instead. + + .. versionadded:: 4.4 """ @@ -471,7 +482,18 @@ class QueryType(str, enum.Enum): """Used to encrypt a value for an equality query.""" RANGE = "range" - """Used to encrypt a value for a range query.""" + """Used to encrypt a value for a range query. + + .. versionadded:: 4.9 + """ + + RANGEPREVIEW = "RangePreview" + """**DEPRECATED** - Used to encrypt a value for a rangePreview query. + + .. note:: Support for RangePreview is deprecated. Use :attr:`QueryType.RANGE` instead. + + .. 
versionadded:: 4.4 + """ def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions: @@ -568,7 +590,7 @@ class ClientEncryption(Generic[_DocumentType]): raise ConfigurationError( "client-side field level encryption requires the pymongocrypt " "library: install a compatible version with: " - "python -m pip install 'pymongo[encryption]'" + "python -m pip install --upgrade 'pymongo[encryption]'" ) if not isinstance(codec_options, CodecOptions): @@ -841,7 +863,7 @@ class ClientEncryption(Generic[_DocumentType]): :return: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6. - .. versionchanged:: 4.8 + .. versionchanged:: 4.9 Added the `range_opts` parameter. .. versionchanged:: 4.7 @@ -897,7 +919,7 @@ class ClientEncryption(Generic[_DocumentType]): :return: The encrypted expression, a :class:`~bson.RawBSONDocument`. - .. versionchanged:: 4.8 + .. versionchanged:: 4.9 Added the `range_opts` parameter. .. versionchanged:: 4.7 diff --git a/pymongo/synchronous/encryption_options.py b/pymongo/synchronous/encryption_options.py deleted file mode 100644 index 5b1cebd7b..000000000 --- a/pymongo/synchronous/encryption_options.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Support for automatic client-side field level encryption.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional - -try: - import pymongocrypt # type:ignore[import] # noqa: F401 - - _HAVE_PYMONGOCRYPT = True -except ImportError: - _HAVE_PYMONGOCRYPT = False -from bson import int64 -from pymongo.errors import ConfigurationError -from pymongo.synchronous.common import validate_is_mapping -from pymongo.synchronous.uri_parser import _parse_kms_tls_options - -if TYPE_CHECKING: - from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.typings import _DocumentTypeArg - -_IS_SYNC = True - - -class AutoEncryptionOpts: - """Options to configure automatic client-side field level encryption.""" - - def __init__( - self, - kms_providers: Mapping[str, Any], - key_vault_namespace: str, - key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None, - schema_map: Optional[Mapping[str, Any]] = None, - bypass_auto_encryption: bool = False, - mongocryptd_uri: str = "mongodb://localhost:27020", - mongocryptd_bypass_spawn: bool = False, - mongocryptd_spawn_path: str = "mongocryptd", - mongocryptd_spawn_args: Optional[list[str]] = None, - kms_tls_options: Optional[Mapping[str, Any]] = None, - crypt_shared_lib_path: Optional[str] = None, - crypt_shared_lib_required: bool = False, - bypass_query_analysis: bool = False, - encrypted_fields_map: Optional[Mapping[str, Any]] = None, - ) -> None: - """Options to configure automatic client-side field level encryption. - - Automatic client-side field level encryption requires MongoDB >=4.2 - enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not - supported for operations on a database or view and will result in - error. - - Although automatic encryption requires MongoDB >=4.2 enterprise or a - MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all - users. 
To configure automatic *decryption* without automatic - *encryption* set ``bypass_auto_encryption=True``. Explicit - encryption and explicit decryption is also supported for all users - with the :class:`~pymongo.encryption.ClientEncryption` class. - - See :ref:`automatic-client-side-encryption` for an example. - - :param kms_providers: Map of KMS provider options. The `kms_providers` - map values differ by provider: - - - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. - These are the AWS access key ID and AWS secret access key used - to generate KMS messages. An optional "sessionToken" may be - included to support temporary AWS credentials. - - `azure`: Map with "tenantId", "clientId", and "clientSecret" as - strings. Additionally, "identityPlatformEndpoint" may also be - specified as a string (defaults to 'login.microsoftonline.com'). - These are the Azure Active Directory credentials used to - generate Azure Key Vault messages. - - `gcp`: Map with "email" as a string and "privateKey" - as `bytes` or a base64 encoded string. - Additionally, "endpoint" may also be specified as a string - (defaults to 'oauth2.googleapis.com'). These are the - credentials used to generate Google Cloud KMS messages. - - `kmip`: Map with "endpoint" as a host with required port. - For example: ``{"endpoint": "example.com:443"}``. - - `local`: Map with "key" as `bytes` (96 bytes in length) or - a base64 encoded string which decodes - to 96 bytes. "key" is the master key used to encrypt/decrypt - data keys. This key should be generated and stored as securely - as possible. - - KMS providers may be specified with an optional name suffix - separated by a colon, for example "kmip:name" or "aws:name". - Named KMS providers do not support :ref:`CSFLE on-demand credentials`. - Named KMS providers enables more than one of each KMS provider type to be configured. 
- For example, to configure multiple local KMS providers:: - - kms_providers = { - "local": {"key": local_kek1}, # Unnamed KMS provider. - "local:myname": {"key": local_kek2}, # Named KMS provider with name "myname". - } - - :param key_vault_namespace: The namespace for the key vault collection. - The key vault collection contains all data keys used for encryption - and decryption. Data keys are stored as documents in this MongoDB - collection. Data keys are protected with encryption by a KMS - provider. - :param key_vault_client: By default, the key vault collection - is assumed to reside in the same MongoDB cluster as the encrypted - MongoClient. Use this option to route data key queries to a - separate MongoDB cluster. - :param schema_map: Map of collection namespace ("db.coll") to - JSON Schema. By default, a collection's JSONSchema is periodically - polled with the listCollections command. But a JSONSchema may be - specified locally with the schemaMap option. - - **Supplying a `schema_map` provides more security than relying on - JSON Schemas obtained from the server. It protects against a - malicious server advertising a false JSON Schema, which could trick - the client into sending unencrypted data that should be - encrypted.** - - Schemas supplied in the schemaMap only apply to configuring - automatic encryption for client side encryption. Other validation - rules in the JSON schema will not be enforced by the driver and - will result in an error. - :param bypass_auto_encryption: If ``True``, automatic - encryption will be disabled but automatic decryption will still be - enabled. Defaults to ``False``. - :param mongocryptd_uri: The MongoDB URI used to connect - to the *local* mongocryptd process. Defaults to - ``'mongodb://localhost:27020'``. - :param mongocryptd_bypass_spawn: If ``True``, the encrypted - MongoClient will not attempt to spawn the mongocryptd process. - Defaults to ``False``. 
- :param mongocryptd_spawn_path: Used for spawning the - mongocryptd process. Defaults to ``'mongocryptd'`` and spawns - mongocryptd from the system path. - :param mongocryptd_spawn_args: A list of string arguments to - use when spawning the mongocryptd process. Defaults to - ``['--idleShutdownTimeoutSecs=60']``. If the list does not include - the ``idleShutdownTimeoutSecs`` option then - ``'--idleShutdownTimeoutSecs=60'`` will be added. - :param kms_tls_options: A map of KMS provider names to TLS - options to use when creating secure connections to KMS providers. - Accepts the same TLS options as - :class:`pymongo.mongo_client.MongoClient`. For example, to - override the system default CA file:: - - kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} - - Or to supply a client certificate:: - - kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} - :param crypt_shared_lib_path: Override the path to load the crypt_shared library. - :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is - unable to load the crypt_shared library. - :param bypass_query_analysis: If ``True``, disable automatic analysis - of outgoing commands. Set `bypass_query_analysis` to use explicit - encryption on indexed fields without the MongoDB Enterprise Advanced - licensed crypt_shared library. - :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents - that described the encrypted fields for Queryable Encryption. For example:: - - { - "db.encryptedCollection": { - "escCollection": "enxcol_.encryptedCollection.esc", - "ecocCollection": "enxcol_.encryptedCollection.ecoc", - "fields": [ - { - "path": "firstName", - "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), - "bsonType": "string", - "queries": {"queryType": "equality"} - }, - { - "path": "ssn", - "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), - "bsonType": "string" - } - ] - } - } - - .. 
versionchanged:: 4.2 - Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, - and `bypass_query_analysis` parameters. - - .. versionchanged:: 4.0 - Added the `kms_tls_options` parameter and the "kmip" KMS provider. - - .. versionadded:: 3.9 - """ - if not _HAVE_PYMONGOCRYPT: - raise ConfigurationError( - "client side encryption requires the pymongocrypt library: " - "install a compatible version with: " - "python -m pip install 'pymongo[encryption]'" - ) - if encrypted_fields_map: - validate_is_mapping("encrypted_fields_map", encrypted_fields_map) - self._encrypted_fields_map = encrypted_fields_map - self._bypass_query_analysis = bypass_query_analysis - self._crypt_shared_lib_path = crypt_shared_lib_path - self._crypt_shared_lib_required = crypt_shared_lib_required - self._kms_providers = kms_providers - self._key_vault_namespace = key_vault_namespace - self._key_vault_client = key_vault_client - self._schema_map = schema_map - self._bypass_auto_encryption = bypass_auto_encryption - self._mongocryptd_uri = mongocryptd_uri - self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn - self._mongocryptd_spawn_path = mongocryptd_spawn_path - if mongocryptd_spawn_args is None: - mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] - self._mongocryptd_spawn_args = mongocryptd_spawn_args - if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError("mongocryptd_spawn_args must be a list") - if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): - self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") - # Maps KMS provider name to a SSLContext. 
- self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) - self._bypass_query_analysis = bypass_query_analysis - - -class RangeOpts: - """Options to configure encrypted queries using the range algorithm.""" - - def __init__( - self, - sparsity: int, - trim_factor: int, - min: Optional[Any] = None, - max: Optional[Any] = None, - precision: Optional[int] = None, - ) -> None: - """Options to configure encrypted queries using the range algorithm. - - :param sparsity: An integer. - :param trim_factor: An integer. - :param min: A BSON scalar value corresponding to the type being queried. - :param max: A BSON scalar value corresponding to the type being queried. - :param precision: An integer, may only be set for double or decimal128 types. - - .. versionadded:: 4.4 - """ - self.min = min - self.max = max - self.sparsity = sparsity - self.trim_factor = trim_factor - self.precision = precision - - @property - def document(self) -> dict[str, Any]: - doc = {} - for k, v in [ - ("sparsity", int64.Int64(self.sparsity)), - ("trimFactor", self.trim_factor), - ("precision", self.precision), - ("min", self.min), - ("max", self.max), - ]: - if v is not None: - doc[k] = v - return doc diff --git a/pymongo/synchronous/event_loggers.py b/pymongo/synchronous/event_loggers.py deleted file mode 100644 index fe9dd899d..000000000 --- a/pymongo/synchronous/event_loggers.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2020-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Example event logger classes. - -.. versionadded:: 3.11 - -These loggers can be registered using :func:`register` or -:class:`~pymongo.mongo_client.MongoClient`. - -``monitoring.register(CommandLogger())`` - -or - -``MongoClient(event_listeners=[CommandLogger()])`` -""" -from __future__ import annotations - -import logging - -from pymongo.synchronous import monitoring - -_IS_SYNC = True - - -class CommandLogger(monitoring.CommandListener): - """A simple listener that logs command events. - - Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, - :class:`~pymongo.monitoring.CommandSucceededEvent` and - :class:`~pymongo.monitoring.CommandFailedEvent` events and - logs them at the `INFO` severity level using :mod:`logging`. - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.CommandStartedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} started on server " - f"{event.connection_id}" - ) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"succeeded in {event.duration_micros} " - "microseconds" - ) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"failed in {event.duration_micros} " - "microseconds" - ) - - -class ServerLogger(monitoring.ServerListener): - """A simple listener that logs server discovery events. - - Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, - :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.ServerClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. 
versionadded:: 3.11 - """ - - def opened(self, event: monitoring.ServerOpeningEvent) -> None: - logging.info(f"Server {event.server_address} added to topology {event.topology_id}") - - def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - f"Server {event.server_address} changed type from " - f"{event.previous_description.server_type_name} to " - f"{event.new_description.server_type_name}" - ) - - def closed(self, event: monitoring.ServerClosedEvent) -> None: - logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") - - -class HeartbeatLogger(monitoring.ServerHeartbeatListener): - """A simple listener that logs server heartbeat events. - - Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, - :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, - and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: - logging.info(f"Heartbeat sent to server {event.connection_id}") - - def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: - # The reply.document attribute was added in PyMongo 3.4. - logging.info( - f"Heartbeat to server {event.connection_id} " - "succeeded with reply " - f"{event.reply.document}" - ) - - def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: - logging.warning( - f"Heartbeat to server {event.connection_id} failed with error {event.reply}" - ) - - -class TopologyLogger(monitoring.TopologyListener): - """A simple listener that logs server topology events. 
- - Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, - :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.TopologyClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def opened(self, event: monitoring.TopologyOpenedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} opened") - - def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: - logging.info(f"Topology description updated for topology id {event.topology_id}") - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - f"Topology {event.topology_id} changed type from " - f"{event.previous_description.topology_type_name} to " - f"{event.new_description.topology_type_name}" - ) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event: monitoring.TopologyClosedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} closed") - - -class ConnectionPoolLogger(monitoring.ConnectionPoolListener): - """A simple listener that logs server connection pool events. 
- - Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, - :class:`~pymongo.monitoring.PoolClearedEvent`, - :class:`~pymongo.monitoring.PoolClosedEvent`, - :~pymongo.monitoring.class:`ConnectionCreatedEvent`, - :class:`~pymongo.monitoring.ConnectionReadyEvent`, - :class:`~pymongo.monitoring.ConnectionClosedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, - and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: - logging.info(f"[pool {event.address}] pool created") - - def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: - logging.info(f"[pool {event.address}] pool ready") - - def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: - logging.info(f"[pool {event.address}] pool cleared") - - def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: - logging.info(f"[pool {event.address}] pool closed") - - def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: - logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") - - def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" - ) - - def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] " - f'connection closed, reason: "{event.reason}"' - ) - - def connection_check_out_started( - self, event: monitoring.ConnectionCheckOutStartedEvent - ) -> None: - logging.info(f"[pool {event.address}] connection check out started") - - def connection_check_out_failed(self, event: 
monitoring.ConnectionCheckOutFailedEvent) -> None: - logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") - - def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" - ) - - def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" - ) diff --git a/pymongo/synchronous/hello.py b/pymongo/synchronous/hello.py deleted file mode 100644 index 5c1d8438f..000000000 --- a/pymongo/synchronous/hello.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2021-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Helpers for the 'hello' and legacy hello commands.""" -from __future__ import annotations - -import copy -import datetime -import itertools -from typing import Any, Generic, Mapping, Optional - -from bson.objectid import ObjectId -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import common -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.typings import ClusterTime, _DocumentType - -_IS_SYNC = True - - -def _get_server_type(doc: Mapping[str, Any]) -> int: - """Determine the server type from a hello response.""" - if not doc.get("ok"): - return SERVER_TYPE.Unknown - - if doc.get("serviceId"): - return SERVER_TYPE.LoadBalancer - elif doc.get("isreplicaset"): - return SERVER_TYPE.RSGhost - elif doc.get("setName"): - if doc.get("hidden"): - return SERVER_TYPE.RSOther - elif doc.get(HelloCompat.PRIMARY): - return SERVER_TYPE.RSPrimary - elif doc.get(HelloCompat.LEGACY_PRIMARY): - return SERVER_TYPE.RSPrimary - elif doc.get("secondary"): - return SERVER_TYPE.RSSecondary - elif doc.get("arbiterOnly"): - return SERVER_TYPE.RSArbiter - else: - return SERVER_TYPE.RSOther - elif doc.get("msg") == "isdbgrid": - return SERVER_TYPE.Mongos - else: - return SERVER_TYPE.Standalone - - -class Hello(Generic[_DocumentType]): - """Parse a hello response from the server. - - .. versionadded:: 3.12 - """ - - __slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable") - - def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None: - self._server_type = _get_server_type(doc) - self._doc: _DocumentType = doc - self._is_writable = self._server_type in ( - SERVER_TYPE.RSPrimary, - SERVER_TYPE.Standalone, - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - - self._is_readable = self.server_type == SERVER_TYPE.RSSecondary or self._is_writable - self._awaitable = awaitable - - @property - def document(self) -> _DocumentType: - """The complete hello command response document. - - .. 
versionadded:: 3.4 - """ - return copy.copy(self._doc) - - @property - def server_type(self) -> int: - return self._server_type - - @property - def all_hosts(self) -> set[tuple[str, int]]: - """List of hosts, passives, and arbiters known to this server.""" - return set( - map( - common.clean_node, - itertools.chain( - self._doc.get("hosts", []), - self._doc.get("passives", []), - self._doc.get("arbiters", []), - ), - ) - ) - - @property - def tags(self) -> Mapping[str, Any]: - """Replica set member tags or empty dict.""" - return self._doc.get("tags", {}) - - @property - def primary(self) -> Optional[tuple[str, int]]: - """This server's opinion about who the primary is, or None.""" - if self._doc.get("primary"): - return common.partition_node(self._doc["primary"]) - else: - return None - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self._doc.get("setName") - - @property - def max_bson_size(self) -> int: - return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE) - - @property - def max_message_size(self) -> int: - return self._doc.get("maxMessageSizeBytes", 2 * self.max_bson_size) - - @property - def max_write_batch_size(self) -> int: - return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE) - - @property - def min_wire_version(self) -> int: - return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION) - - @property - def max_wire_version(self) -> int: - return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION) - - @property - def set_version(self) -> Optional[int]: - return self._doc.get("setVersion") - - @property - def election_id(self) -> Optional[ObjectId]: - return self._doc.get("electionId") - - @property - def cluster_time(self) -> Optional[ClusterTime]: - return self._doc.get("$clusterTime") - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - return self._doc.get("logicalSessionTimeoutMinutes") - - @property - def is_writable(self) -> bool: - 
return self._is_writable - - @property - def is_readable(self) -> bool: - return self._is_readable - - @property - def me(self) -> Optional[tuple[str, int]]: - me = self._doc.get("me") - if me: - return common.clean_node(me) - return None - - @property - def last_write_date(self) -> Optional[datetime.datetime]: - return self._doc.get("lastWrite", {}).get("lastWriteDate") - - @property - def compressors(self) -> Optional[list[str]]: - return self._doc.get("compression") - - @property - def sasl_supported_mechs(self) -> list[str]: - """Supported authentication mechanisms for the current user. - - For example:: - - >>> hello.sasl_supported_mechs - ["SCRAM-SHA-1", "SCRAM-SHA-256"] - - """ - return self._doc.get("saslSupportedMechs", []) - - @property - def speculative_authenticate(self) -> Optional[Mapping[str, Any]]: - """The speculativeAuthenticate field.""" - return self._doc.get("speculativeAuthenticate") - - @property - def topology_version(self) -> Optional[Mapping[str, Any]]: - return self._doc.get("topologyVersion") - - @property - def awaitable(self) -> bool: - return self._awaitable - - @property - def service_id(self) -> Optional[ObjectId]: - return self._doc.get("serviceId") - - @property - def hello_ok(self) -> bool: - return self._doc.get("helloOk", False) - - @property - def connection_id(self) -> Optional[int]: - return self._doc.get("connectionId") diff --git a/pymongo/synchronous/hello_compat.py b/pymongo/synchronous/hello_compat.py deleted file mode 100644 index 126ed4bf5..000000000 --- a/pymongo/synchronous/hello_compat.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The HelloCompat class, placed here to break circular import issues.""" -from __future__ import annotations - -_IS_SYNC = True - - -class HelloCompat: - CMD = "hello" - LEGACY_CMD = "ismaster" - PRIMARY = "isWritablePrimary" - LEGACY_PRIMARY = "ismaster" - LEGACY_ERROR = "not master" diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 3f0df2a3c..e6bbf5d51 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2009-present MongoDB, Inc. +# Copyright 2024-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,270 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Bits and pieces used by the driver that don't really fit elsewhere.""" +"""Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations import builtins import sys -import traceback -from collections import abc from typing import ( - TYPE_CHECKING, Any, Callable, - Container, - Iterable, - Mapping, - NoReturn, - Optional, - Sequence, TypeVar, - Union, cast, ) -from pymongo import ASCENDING from pymongo.errors import ( - CursorNotFound, - DuplicateKeyError, - ExecutionTimeout, - NotPrimaryError, OperationFailure, - WriteConcernError, - WriteError, - WTimeoutError, - _wtimeout_error, ) -from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE -from pymongo.synchronous.hello_compat import HelloCompat - -if TYPE_CHECKING: - from pymongo.cursor_shared import _Hint - from pymongo.synchronous.operations import _IndexList - from pymongo.synchronous.typings import _DocumentOut +from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE _IS_SYNC = True - -def _gen_index_name(keys: _IndexList) -> str: - """Generate an index name from the set of fields it is over.""" - return "_".join(["{}_{}".format(*item) for item in keys]) - - -def _index_list( - key_or_list: _Hint, direction: Optional[Union[int, str]] = None -) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: - """Helper to generate a list of (key, direction) pairs. - - Takes such a list, or a single key, or a single key and direction. 
- """ - if direction is not None: - if not isinstance(key_or_list, str): - raise TypeError("Expected a string and a direction") - return [(key_or_list, direction)] - else: - if isinstance(key_or_list, str): - return [(key_or_list, ASCENDING)] - elif isinstance(key_or_list, abc.ItemsView): - return list(key_or_list) # type: ignore[arg-type] - elif isinstance(key_or_list, abc.Mapping): - return list(key_or_list.items()) - elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, key_or_list must be an instance of list") - values: list[tuple[str, int]] = [] - for item in key_or_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - values.append(item) - return values - - -def _index_document(index_list: _IndexList) -> dict[str, Any]: - """Helper to generate an index specifying document. - - Takes a list of (key, direction) pairs. - """ - if not isinstance(index_list, (list, tuple, abc.Mapping)): - raise TypeError( - "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) - ) - if not len(index_list): - raise ValueError("key_or_list must not be empty") - - index: dict[str, Any] = {} - - if isinstance(index_list, abc.Mapping): - for key in index_list: - value = index_list[key] - _validate_index_key_pair(key, value) - index[key] = value - else: - for item in index_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - key, value = item - _validate_index_key_pair(key, value) - index[key] = value - return index - - -def _validate_index_key_pair(key: Any, value: Any) -> None: - if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") - if not isinstance(value, (str, int, abc.Mapping)): - raise TypeError( - "second item in each key pair must be 1, -1, " - "'2d', or another valid MongoDB index specifier." 
- ) - - -def _check_command_response( - response: _DocumentOut, - max_wire_version: Optional[int], - allowable_errors: Optional[Container[Union[int, str]]] = None, - parse_write_concern_error: bool = False, -) -> None: - """Check the response to a command for errors.""" - if "ok" not in response: - # Server didn't recognize our message as a command. - raise OperationFailure( - response.get("$err"), # type: ignore[arg-type] - response.get("code"), - response, - max_wire_version, - ) - - if parse_write_concern_error and "writeConcernError" in response: - _error = response["writeConcernError"] - _labels = response.get("errorLabels") - if _labels: - _error.update({"errorLabels": _labels}) - _raise_write_concern_error(_error) - - if response["ok"]: - return - - details = response - # Mongos returns the error details in a 'raw' object - # for some errors. - if "raw" in response: - for shard in response["raw"].values(): - # Grab the first non-empty raw error from a shard. - if shard.get("errmsg") and not shard.get("ok"): - details = shard - break - - errmsg = details["errmsg"] - code = details.get("code") - - # For allowable errors, only check for error messages when the code is not - # included. 
- if allowable_errors: - if code is not None: - if code in allowable_errors: - return - elif errmsg in allowable_errors: - return - - # Server is "not primary" or "recovering" - if code is not None: - if code in _NOT_PRIMARY_CODES: - raise NotPrimaryError(errmsg, response) - elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: - raise NotPrimaryError(errmsg, response) - - # Other errors - # findAndModify with upsert can raise duplicate key error - if code in (11000, 11001, 12582): - raise DuplicateKeyError(errmsg, code, response, max_wire_version) - elif code == 50: - raise ExecutionTimeout(errmsg, code, response, max_wire_version) - elif code == 43: - raise CursorNotFound(errmsg, code, response, max_wire_version) - - raise OperationFailure(errmsg, code, response, max_wire_version) - - -def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: - # If the last batch had multiple errors only report - # the last error to emulate continue_on_error. - error = write_errors[-1] - if error.get("code") == 11000: - raise DuplicateKeyError(error.get("errmsg"), 11000, error) - raise WriteError(error.get("errmsg"), error.get("code"), error) - - -def _raise_write_concern_error(error: Any) -> NoReturn: - if _wtimeout_error(error): - # Make sure we raise WTimeoutError - raise WTimeoutError(error.get("errmsg"), error.get("code"), error) - raise WriteConcernError(error.get("errmsg"), error.get("code"), error) - - -def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: - """Return the writeConcernError or None.""" - wce = result.get("writeConcernError") - if wce: - # The server reports errorLabels at the top level but it's more - # convenient to attach it to the writeConcernError doc itself. - error_labels = result.get("errorLabels") - if error_labels: - # Copy to avoid changing the original document. 
- wce = wce.copy() - wce["errorLabels"] = error_labels - return wce - - -def _check_write_command_response(result: Mapping[str, Any]) -> None: - """Backward compatibility helper for write command error handling.""" - # Prefer write errors over write concern errors - write_errors = result.get("writeErrors") - if write_errors: - _raise_last_write_error(write_errors) - - wce = _get_wce_doc(result) - if wce: - _raise_write_concern_error(wce) - - -def _fields_list_to_dict( - fields: Union[Mapping[str, Any], Iterable[str]], option_name: str -) -> Mapping[str, Any]: - """Takes a sequence of field names and returns a matching dictionary. - - ["a", "b"] becomes {"a": 1, "b": 1} - - and - - ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} - """ - if isinstance(fields, abc.Mapping): - return fields - - if isinstance(fields, (abc.Sequence, abc.Set)): - if not all(isinstance(field, str) for field in fields): - raise TypeError(f"{option_name} must be a list of key names, each an instance of str") - return dict.fromkeys(fields, 1) - - raise TypeError(f"{option_name} must be a mapping or list of key names") - - -def _handle_exception() -> None: - """Print exceptions raised by subscribers to stderr.""" - # Heavily influenced by logging.Handler.handleError. 
- - # See note here: - # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ - if sys.stderr: - einfo = sys.exc_info() - try: - traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) - except OSError: - pass - finally: - del einfo - - # See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories F = TypeVar("F", bound=Callable[..., Any]) @@ -283,7 +38,7 @@ F = TypeVar("F", bound=Callable[..., Any]) def _handle_reauth(func: F) -> F: def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) - from pymongo.synchronous.message import _BulkWriteContext + from pymongo.message import _BulkWriteContext from pymongo.synchronous.pool import Connection try: @@ -301,7 +56,7 @@ def _handle_reauth(func: F) -> F: conn = arg break if isinstance(arg, _BulkWriteContext): - conn = arg.conn + conn = arg.conn # type: ignore[assignment] break if conn: conn.authenticate(reauthenticate=True) @@ -318,5 +73,5 @@ if sys.version_info >= (3, 10): else: def next(cls: Any) -> Any: - """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext.""" + """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next.""" return cls.__next__() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 31b9cbc5f..6b9f6231b 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -57,7 +57,8 @@ from typing import ( from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, helpers_constants +from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -72,29 +73,21 @@ from pymongo.errors import ( WriteConcernError, ) from pymongo.lock import _HAS_REGISTER_AT_FORK, 
_create_lock, _release_locks +from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.message import _CursorAddress, _GetMore, _Query +from pymongo.monitoring import ConnectionClosedReason +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import ( - client_session, - common, - database, - helpers, - message, - periodic_executor, - uri_parser, -) +from pymongo.synchronous import client_session, database, periodic_executor from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream -from pymongo.synchronous.client_options import ClientOptions from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.logger import _CLIENT_LOGGER, _log_or_warn -from pymongo.synchronous.monitoring import ConnectionClosedReason -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.synchronous.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE, TopologyDescription -from pymongo.synchronous.typings import ( +from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription +from pymongo.typings import ( ClusterTime, _Address, _CollationIn, @@ -102,7 +95,7 @@ from pymongo.synchronous.typings import ( _DocumentTypeArg, _Pipeline, ) -from pymongo.synchronous.uri_parser import ( +from pymongo.uri_parser import ( _check_options, _handle_option_deprecations, _handle_security_options, @@ -115,20 +108,22 @@ if TYPE_CHECKING: from bson.objectid import ObjectId from pymongo.read_concern import ReadConcern + 
from pymongo.response import Response + from pymongo.server_selectors import Selection from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession, _ServerSession from pymongo.synchronous.cursor import _ConnectionManager - from pymongo.synchronous.message import _CursorAddress, _GetMore, _Query from pymongo.synchronous.pool import Connection - from pymongo.synchronous.response import Response from pymongo.synchronous.server import Server - from pymongo.synchronous.server_selectors import Selection T = TypeVar("T") _WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T] -_ReadCall = Callable[[Optional["ClientSession"], "Server", "Connection", _ServerMode], T] +_ReadCall = Callable[ + [Optional["ClientSession"], "Server", "Connection", _ServerMode], + T, +] _IS_SYNC = True @@ -1669,13 +1664,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if operation.conn_mgr: server = self._select_server( operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] operation.name, address=address, ) with operation.conn_mgr._alock: - with _MongoClientErrorHandler(self, server, operation.session) as err_handler: + with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( operation.conn_mgr.conn, @@ -1705,9 +1700,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return self._retryable_read( _cmd, operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] address=address, - retryable=isinstance(operation, message._Query), + retryable=isinstance(operation, _Query), operation=operation.name, ) @@ -1991,7 +1986,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # can be caught in _process_periodic_tasks raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # 
Don't re-open topology if it's closed and there's no pending cursors. if address_to_cursor_ids: @@ -2003,7 +1998,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # This method is run periodically by a background thread. def _process_periodic_tasks(self) -> None: @@ -2017,7 +2012,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if isinstance(exc, InvalidOperation) and self._topology._closed: return else: - helpers._handle_exception() + helpers_shared._handle_exception() def _return_server_session( self, server_session: Union[_ServerSession, _EmptyServerSession] @@ -2227,7 +2222,7 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong # Do not consult writeConcernError for pre-4.4 mongos. if isinstance(exc, WriteConcernError) and is_mongos: pass - elif code in helpers_constants._RETRYABLE_ERROR_CODES: + elif code in helpers_shared._RETRYABLE_ERROR_CODES: exc._add_error_label("RetryableWriteError") # Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is @@ -2383,7 +2378,7 @@ class _ClientConnectionRetryable(Generic[T]): exc_code = getattr(exc, "code", None) if self._is_not_eligible_for_retry() or ( isinstance(exc, OperationFailure) - and exc_code not in helpers_constants._RETRYABLE_ERROR_CODES + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES ): raise self._retrying = True diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 96849e734..8106c1922 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -21,16 +21,17 @@ import time import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from pymongo import common from pymongo._csot import MovingMinimum from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, 
_OperationCancelled +from pymongo.hello import Hello from pymongo.lock import _create_lock -from pymongo.synchronous import common, periodic_executor -from pymongo.synchronous.hello import Hello +from pymongo.pool_options import _is_faas +from pymongo.read_preferences import MovingAverage +from pymongo.server_description import ServerDescription +from pymongo.srv_resolver import _SrvResolver +from pymongo.synchronous import periodic_executor from pymongo.synchronous.periodic_executor import _shutdown_executors -from pymongo.synchronous.pool import _is_faas -from pymongo.synchronous.read_preferences import MovingAverage -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.srv_resolver import _SrvResolver if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext diff --git a/pymongo/synchronous/monitoring.py b/pymongo/synchronous/monitoring.py deleted file mode 100644 index a4b729688..000000000 --- a/pymongo/synchronous/monitoring.py +++ /dev/null @@ -1,1903 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Tools to monitor driver events. - -.. versionadded:: 3.1 - -.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below - are included in the PyMongo distribution under the - :mod:`~pymongo.event_loggers` submodule. - -Use :func:`register` to register global listeners for specific events. 
-Listeners must inherit from one of the abstract classes below and implement -the correct functions for that class. - -For example, a simple command logger might be implemented like this:: - - import logging - - from pymongo import monitoring - - class CommandLogger(monitoring.CommandListener): - - def started(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} started on server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "succeeded in {0.duration_micros} " - "microseconds".format(event)) - - def failed(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "failed in {0.duration_micros} " - "microseconds".format(event)) - - monitoring.register(CommandLogger()) - -Server discovery and monitoring events are also available. For example:: - - class ServerLogger(monitoring.ServerListener): - - def opened(self, event): - logging.info("Server {0.server_address} added to topology " - "{0.topology_id}".format(event)) - - def description_changed(self, event): - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - "Server {0.server_address} changed type from " - "{0.previous_description.server_type_name} to " - "{0.new_description.server_type_name}".format(event)) - - def closed(self, event): - logging.warning("Server {0.server_address} removed from topology " - "{0.topology_id}".format(event)) - - - class HeartbeatLogger(monitoring.ServerHeartbeatListener): - - def started(self, event): - logging.info("Heartbeat sent to server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - # The reply.document attribute was added in PyMongo 3.4. 
- logging.info("Heartbeat to server {0.connection_id} " - "succeeded with reply " - "{0.reply.document}".format(event)) - - def failed(self, event): - logging.warning("Heartbeat to server {0.connection_id} " - "failed with error {0.reply}".format(event)) - - class TopologyLogger(monitoring.TopologyListener): - - def opened(self, event): - logging.info("Topology with id {0.topology_id} " - "opened".format(event)) - - def description_changed(self, event): - logging.info("Topology description updated for " - "topology id {0.topology_id}".format(event)) - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - "Topology {0.topology_id} changed type from " - "{0.previous_description.topology_type_name} to " - "{0.new_description.topology_type_name}".format(event)) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event): - logging.info("Topology with id {0.topology_id} " - "closed".format(event)) - -Connection monitoring and pooling events are also available. 
For example:: - - class ConnectionPoolLogger(ConnectionPoolListener): - - def pool_created(self, event): - logging.info("[pool {0.address}] pool created".format(event)) - - def pool_ready(self, event): - logging.info("[pool {0.address}] pool is ready".format(event)) - - def pool_cleared(self, event): - logging.info("[pool {0.address}] pool cleared".format(event)) - - def pool_closed(self, event): - logging.info("[pool {0.address}] pool closed".format(event)) - - def connection_created(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection created".format(event)) - - def connection_ready(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection setup succeeded".format(event)) - - def connection_closed(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection closed, reason: " - "{0.reason}".format(event)) - - def connection_check_out_started(self, event): - logging.info("[pool {0.address}] connection check out " - "started".format(event)) - - def connection_check_out_failed(self, event): - logging.info("[pool {0.address}] connection check out " - "failed, reason: {0.reason}".format(event)) - - def connection_checked_out(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked out of pool".format(event)) - - def connection_checked_in(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked into pool".format(event)) - - -Event listeners can also be registered per instance of -:class:`~pymongo.mongo_client.MongoClient`:: - - client = MongoClient(event_listeners=[CommandLogger()]) - -Note that previously registered global listeners are automatically included -when configuring per client event listeners. Registering a new global listener -will not add that listener to existing client instances. - -.. note:: Events are delivered **synchronously**. 
Application threads block - waiting for event handlers (e.g. :meth:`~CommandListener.started`) to - return. Care must be taken to ensure that your event handlers are efficient - enough to not adversely affect overall application performance. - -.. warning:: The command documents published through this API are *not* copies. - If you intend to modify them in any way you must copy them in your event - handler first. -""" - -from __future__ import annotations - -import datetime -from collections import abc, namedtuple -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from bson.objectid import ObjectId -from pymongo.helpers_constants import _SENSITIVE_COMMANDS -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.helpers import _handle_exception -from pymongo.synchronous.typings import _Address, _DocumentOut - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.synchronous.server_description import ServerDescription - from pymongo.synchronous.topology_description import TopologyDescription - -_IS_SYNC = True - -_Listeners = namedtuple( - "_Listeners", - ( - "command_listeners", - "server_listeners", - "server_heartbeat_listeners", - "topology_listeners", - "cmap_listeners", - ), -) - -_LISTENERS = _Listeners([], [], [], [], []) - - -class _EventListener: - """Abstract base class for all event listeners.""" - - -class CommandListener(_EventListener): - """Abstract base class for command listeners. - - Handles `CommandStartedEvent`, `CommandSucceededEvent`, - and `CommandFailedEvent`. - """ - - def started(self, event: CommandStartedEvent) -> None: - """Abstract method to handle a `CommandStartedEvent`. - - :param event: An instance of :class:`CommandStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: CommandSucceededEvent) -> None: - """Abstract method to handle a `CommandSucceededEvent`. 
- - :param event: An instance of :class:`CommandSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: CommandFailedEvent) -> None: - """Abstract method to handle a `CommandFailedEvent`. - - :param event: An instance of :class:`CommandFailedEvent`. - """ - raise NotImplementedError - - -class ConnectionPoolListener(_EventListener): - """Abstract base class for connection pool listeners. - - Handles all of the connection pool events defined in the Connection - Monitoring and Pooling Specification: - :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, - :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, - :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, - :class:`ConnectionCheckOutStartedEvent`, - :class:`ConnectionCheckOutFailedEvent`, - :class:`ConnectionCheckedOutEvent`, - and :class:`ConnectionCheckedInEvent`. - - .. versionadded:: 3.9 - """ - - def pool_created(self, event: PoolCreatedEvent) -> None: - """Abstract method to handle a :class:`PoolCreatedEvent`. - - Emitted when a connection Pool is created. - - :param event: An instance of :class:`PoolCreatedEvent`. - """ - raise NotImplementedError - - def pool_ready(self, event: PoolReadyEvent) -> None: - """Abstract method to handle a :class:`PoolReadyEvent`. - - Emitted when a connection Pool is marked ready. - - :param event: An instance of :class:`PoolReadyEvent`. - - .. versionadded:: 4.0 - """ - raise NotImplementedError - - def pool_cleared(self, event: PoolClearedEvent) -> None: - """Abstract method to handle a `PoolClearedEvent`. - - Emitted when a connection Pool is cleared. - - :param event: An instance of :class:`PoolClearedEvent`. - """ - raise NotImplementedError - - def pool_closed(self, event: PoolClosedEvent) -> None: - """Abstract method to handle a `PoolClosedEvent`. - - Emitted when a connection Pool is closed. - - :param event: An instance of :class:`PoolClosedEvent`. 
- """ - raise NotImplementedError - - def connection_created(self, event: ConnectionCreatedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCreatedEvent`. - - Emitted when a connection Pool creates a Connection object. - - :param event: An instance of :class:`ConnectionCreatedEvent`. - """ - raise NotImplementedError - - def connection_ready(self, event: ConnectionReadyEvent) -> None: - """Abstract method to handle a :class:`ConnectionReadyEvent`. - - Emitted when a connection has finished its setup, and is now ready to - use. - - :param event: An instance of :class:`ConnectionReadyEvent`. - """ - raise NotImplementedError - - def connection_closed(self, event: ConnectionClosedEvent) -> None: - """Abstract method to handle a :class:`ConnectionClosedEvent`. - - Emitted when a connection Pool closes a connection. - - :param event: An instance of :class:`ConnectionClosedEvent`. - """ - raise NotImplementedError - - def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. - - Emitted when the driver starts attempting to check out a connection. - - :param event: An instance of :class:`ConnectionCheckOutStartedEvent`. - """ - raise NotImplementedError - - def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. - - Emitted when the driver's attempt to check out a connection fails. - - :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. - """ - raise NotImplementedError - - def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. - - Emitted when the driver successfully checks out a connection. - - :param event: An instance of :class:`ConnectionCheckedOutEvent`. 
- """ - raise NotImplementedError - - def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedInEvent`. - - Emitted when the driver checks in a connection back to the connection - Pool. - - :param event: An instance of :class:`ConnectionCheckedInEvent`. - """ - raise NotImplementedError - - -class ServerHeartbeatListener(_EventListener): - """Abstract base class for server heartbeat listeners. - - Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, - and `ServerHeartbeatFailedEvent`. - - .. versionadded:: 3.3 - """ - - def started(self, event: ServerHeartbeatStartedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatStartedEvent`. - - :param event: An instance of :class:`ServerHeartbeatStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: - """Abstract method to handle a `ServerHeartbeatSucceededEvent`. - - :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: ServerHeartbeatFailedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatFailedEvent`. - - :param event: An instance of :class:`ServerHeartbeatFailedEvent`. - """ - raise NotImplementedError - - -class TopologyListener(_EventListener): - """Abstract base class for topology monitoring listeners. - Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and - `TopologyClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: TopologyOpenedEvent) -> None: - """Abstract method to handle a `TopologyOpenedEvent`. - - :param event: An instance of :class:`TopologyOpenedEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: - """Abstract method to handle a `TopologyDescriptionChangedEvent`. - - :param event: An instance of :class:`TopologyDescriptionChangedEvent`. 
- """ - raise NotImplementedError - - def closed(self, event: TopologyClosedEvent) -> None: - """Abstract method to handle a `TopologyClosedEvent`. - - :param event: An instance of :class:`TopologyClosedEvent`. - """ - raise NotImplementedError - - -class ServerListener(_EventListener): - """Abstract base class for server listeners. - Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and - `ServerClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: ServerOpeningEvent) -> None: - """Abstract method to handle a `ServerOpeningEvent`. - - :param event: An instance of :class:`ServerOpeningEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: ServerDescriptionChangedEvent) -> None: - """Abstract method to handle a `ServerDescriptionChangedEvent`. - - :param event: An instance of :class:`ServerDescriptionChangedEvent`. - """ - raise NotImplementedError - - def closed(self, event: ServerClosedEvent) -> None: - """Abstract method to handle a `ServerClosedEvent`. - - :param event: An instance of :class:`ServerClosedEvent`. - """ - raise NotImplementedError - - -def _to_micros(dur: timedelta) -> int: - """Convert duration 'dur' to microseconds.""" - return int(dur.total_seconds() * 10e5) - - -def _validate_event_listeners( - option: str, listeners: Sequence[_EventListeners] -) -> Sequence[_EventListeners]: - """Validate event listeners""" - if not isinstance(listeners, abc.Sequence): - raise TypeError(f"{option} must be a list or tuple") - for listener in listeners: - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {option} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - return listeners - - -def register(listener: _EventListener) -> None: - """Register a global event listener. 
- - :param listener: A subclasses of :class:`CommandListener`, - :class:`ServerHeartbeatListener`, :class:`ServerListener`, - :class:`TopologyListener`, or :class:`ConnectionPoolListener`. - """ - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {listener} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - if isinstance(listener, CommandListener): - _LISTENERS.command_listeners.append(listener) - if isinstance(listener, ServerHeartbeatListener): - _LISTENERS.server_heartbeat_listeners.append(listener) - if isinstance(listener, ServerListener): - _LISTENERS.server_listeners.append(listener) - if isinstance(listener, TopologyListener): - _LISTENERS.topology_listeners.append(listener) - if isinstance(listener, ConnectionPoolListener): - _LISTENERS.cmap_listeners.append(listener) - - -# The "hello" command is also deemed sensitive when attempting speculative -# authentication. 
-def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: - if ( - command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) - and "speculativeAuthenticate" in doc - ): - return True - return False - - -class _CommandEvent: - """Base class for command events.""" - - __slots__ = ( - "__cmd_name", - "__rqst_id", - "__conn_id", - "__op_id", - "__service_id", - "__db", - "__server_conn_id", - ) - - def __init__( - self, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - self.__cmd_name = command_name - self.__rqst_id = request_id - self.__conn_id = connection_id - self.__op_id = operation_id - self.__service_id = service_id - self.__db = database_name - self.__server_conn_id = server_connection_id - - @property - def command_name(self) -> str: - """The command name.""" - return self.__cmd_name - - @property - def request_id(self) -> int: - """The request id for this operation.""" - return self.__rqst_id - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this command was sent to.""" - return self.__conn_id - - @property - def service_id(self) -> Optional[ObjectId]: - """The service_id this command was sent to, or ``None``. - - .. versionadded:: 3.12 - """ - return self.__service_id - - @property - def operation_id(self) -> Optional[int]: - """An id for this series of events or None.""" - return self.__op_id - - @property - def database_name(self) -> str: - """The database_name this command was sent to, or ``""``. - - .. versionadded:: 4.6 - """ - return self.__db - - @property - def server_connection_id(self) -> Optional[int]: - """The server-side connection id for the connection this command was sent on, or ``None``. - - .. 
versionadded:: 4.7 - """ - return self.__server_conn_id - - -class CommandStartedEvent(_CommandEvent): - """Event published when a command starts. - - :param command: The command document. - :param database_name: The name of the database this command was run against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - """ - - __slots__ = ("__cmd",) - - def __init__( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - server_connection_id: Optional[int] = None, - ) -> None: - if not command: - raise ValueError(f"{command!r} is not a valid command") - # Command name must be first key. - command_name = next(iter(command)) - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command): - self.__cmd: _DocumentOut = {} - else: - self.__cmd = command - - @property - def command(self) -> _DocumentOut: - """The command document.""" - return self.__cmd - - @property - def database_name(self) -> str: - """The name of the database this command was run against.""" - return super().database_name - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.service_id, - self.server_connection_id, - ) - - -class CommandSucceededEvent(_CommandEvent): - """Event 
published when a command succeeds. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - - __slots__ = ("__duration_micros", "__reply") - - def __init__( - self, - duration: datetime.timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply): - self.__reply: _DocumentOut = {} - else: - self.__reply = reply - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def reply(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__reply - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.service_id, - self.server_connection_id, - ) - - -class 
CommandFailedEvent(_CommandEvent): - """Event published when a command fails. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - - __slots__ = ("__duration_micros", "__failure") - - def __init__( - self, - duration: datetime.timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - self.__failure = failure - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def failure(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__failure - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, " - "failure: {!r}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.failure, - self.service_id, - self.server_connection_id, - ) - - -class _PoolEvent: - """Base class for pool events.""" - - __slots__ = ("__address",) - - def 
__init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server the pool is attempting - to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class PoolCreatedEvent(_PoolEvent): - """Published when a Connection Pool is created. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__options",) - - def __init__(self, address: _Address, options: dict[str, Any]) -> None: - super().__init__(address) - self.__options = options - - @property - def options(self) -> dict[str, Any]: - """Any non-default pool options that were set on this Connection Pool.""" - return self.__options - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" - - -class PoolReadyEvent(_PoolEvent): - """Published when a Connection Pool is marked ready. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 4.0 - """ - - __slots__ = () - - -class PoolClearedEvent(_PoolEvent): - """Published when a Connection Pool is cleared. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - :param service_id: The service_id this command was sent to, or ``None``. - :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing. - - .. 
versionadded:: 3.9 - """ - - __slots__ = ("__service_id", "__interrupt_connections") - - def __init__( - self, - address: _Address, - service_id: Optional[ObjectId] = None, - interrupt_connections: bool = False, - ) -> None: - super().__init__(address) - self.__service_id = service_id - self.__interrupt_connections = interrupt_connections - - @property - def service_id(self) -> Optional[ObjectId]: - """Connections with this service_id are cleared. - - When service_id is ``None``, all connections in the pool are cleared. - - .. versionadded:: 3.12 - """ - return self.__service_id - - @property - def interrupt_connections(self) -> bool: - """If True, active connections are interrupted during clearing. - - .. versionadded:: 4.7 - """ - return self.__interrupt_connections - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" - - -class PoolClosedEvent(_PoolEvent): - """Published when a Connection Pool is closed. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionClosedEvent`. - - .. versionadded:: 3.9 - """ - - STALE = "stale" - """The pool was cleared, making the connection no longer valid.""" - - IDLE = "idle" - """The connection became stale by being idle for too long (maxIdleTimeMS). - """ - - ERROR = "error" - """The connection experienced an error, making it no longer valid.""" - - POOL_CLOSED = "poolClosed" - """The pool was closed, making the connection no longer valid.""" - - -class ConnectionCheckOutFailedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionCheckOutFailedEvent`. - - .. 
versionadded:: 3.9 - """ - - TIMEOUT = "timeout" - """The connection check out attempt exceeded the specified timeout.""" - - POOL_CLOSED = "poolClosed" - """The pool was previously closed, and cannot provide new connections.""" - - CONN_ERROR = "connectionError" - """The connection check out attempt experienced an error while setting up - a new connection. - """ - - -class _ConnectionEvent: - """Private base class for connection events.""" - - __slots__ = ("__address",) - - def __init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server this connection is - attempting to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class _ConnectionIdEvent(_ConnectionEvent): - """Private base class for connection events with an id.""" - - __slots__ = ("__connection_id",) - - def __init__(self, address: _Address, connection_id: int) -> None: - super().__init__(address) - self.__connection_id = connection_id - - @property - def connection_id(self) -> int: - """The ID of the connection.""" - return self.__connection_id - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" - - -class _ConnectionDurationEvent(_ConnectionIdEvent): - """Private base class for connection events with a duration.""" - - __slots__ = ("__duration",) - - def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: - super().__init__(address, connection_id) - self.__duration = duration - - @property - def duration(self) -> Optional[float]: - """The duration of the connection event. - - .. 
versionadded:: 4.7 - """ - return self.__duration - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" - - -class ConnectionCreatedEvent(_ConnectionIdEvent): - """Published when a Connection Pool creates a Connection object. - - NOTE: This connection is not ready for use until the - :class:`ConnectionReadyEvent` is published. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionReadyEvent(_ConnectionDurationEvent): - """Published when a Connection has finished its setup, and is ready to use. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedEvent(_ConnectionIdEvent): - """Published when a Connection is closed. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - :param reason: A reason explaining why this connection was closed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, connection_id: int, reason: str): - super().__init__(address, connection_id) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why this connection was closed. - - The reason must be one of the strings from the - :class:`ConnectionClosedReason` enum. 
- """ - return self.__reason - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r})".format( - self.__class__.__name__, - self.address, - self.connection_id, - self.__reason, - ) - - -class ConnectionCheckOutStartedEvent(_ConnectionEvent): - """Published when the driver starts attempting to check out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): - """Published when the driver's attempt to check out a connection fails. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param reason: A reason explaining why connection check out failed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: - super().__init__(address=address, connection_id=0, duration=duration) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why connection check out failed. - - The reason must be one of the strings from the - :class:`ConnectionCheckOutFailedReason` enum. - """ - return self.__reason - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" - - -class ConnectionCheckedOutEvent(_ConnectionDurationEvent): - """Published when the driver successfully checks out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckedInEvent(_ConnectionIdEvent): - """Published when the driver checks in a Connection into the Pool. 
- - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class _ServerEvent: - """Base class for server events.""" - - __slots__ = ("__server_address", "__topology_id") - - def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: - self.__server_address = server_address - self.__topology_id = topology_id - - @property - def server_address(self) -> _Address: - """The address (host, port) pair of the server""" - return self.__server_address - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" - - -class ServerDescriptionChangedEvent(_ServerEvent): - """Published when server description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> ServerDescription: - """The previous - :class:`~pymongo.server_description.ServerDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> ServerDescription: - """The new - :class:`~pymongo.server_description.ServerDescription`. 
- """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.server_address, - self.previous_description, - self.new_description, - ) - - -class ServerOpeningEvent(_ServerEvent): - """Published when server is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerClosedEvent(_ServerEvent): - """Published when server is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyEvent: - """Base class for topology description events.""" - - __slots__ = ("__topology_id",) - - def __init__(self, topology_id: ObjectId) -> None: - self.__topology_id = topology_id - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" - - -class TopologyDescriptionChangedEvent(TopologyEvent): - """Published when the topology description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> TopologyDescription: - """The previous - :class:`~pymongo.topology_description.TopologyDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> TopologyDescription: - """The new - :class:`~pymongo.topology_description.TopologyDescription`. 
- """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} topology_id: {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.topology_id, - self.previous_description, - self.new_description, - ) - - -class TopologyOpenedEvent(TopologyEvent): - """Published when the topology is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyClosedEvent(TopologyEvent): - """Published when the topology is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class _ServerHeartbeatEvent: - """Base class for server heartbeat events.""" - - __slots__ = ("__connection_id", "__awaited") - - def __init__(self, connection_id: _Address, awaited: bool = False) -> None: - self.__connection_id = connection_id - self.__awaited = awaited - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this heartbeat was sent - to. - """ - return self.__connection_id - - @property - def awaited(self) -> bool: - """Whether the heartbeat was issued as an awaitable hello command. - - .. versionadded:: 4.6 - """ - return self.__awaited - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" - - -class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): - """Published when a heartbeat is started. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat succeeds. - - .. 
versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Hello: - """An instance of :class:`~pymongo.hello.Hello`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat fails, either with an "ok: 0" - or a socket exception. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Exception: - """A subclass of :exc:`Exception`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. 
versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class _EventListeners: - """Configure event listeners for a client instance. - - Any event listeners registered globally are included by default. - - :param listeners: A list of event listeners. - """ - - def __init__(self, listeners: Optional[Sequence[_EventListener]]): - self.__command_listeners = _LISTENERS.command_listeners[:] - self.__server_listeners = _LISTENERS.server_listeners[:] - lst = _LISTENERS.server_heartbeat_listeners - self.__server_heartbeat_listeners = lst[:] - self.__topology_listeners = _LISTENERS.topology_listeners[:] - self.__cmap_listeners = _LISTENERS.cmap_listeners[:] - if listeners is not None: - for lst in listeners: - if isinstance(lst, CommandListener): - self.__command_listeners.append(lst) - if isinstance(lst, ServerListener): - self.__server_listeners.append(lst) - if isinstance(lst, ServerHeartbeatListener): - self.__server_heartbeat_listeners.append(lst) - if isinstance(lst, TopologyListener): - self.__topology_listeners.append(lst) - if isinstance(lst, ConnectionPoolListener): - self.__cmap_listeners.append(lst) - self.__enabled_for_commands = bool(self.__command_listeners) - self.__enabled_for_server = bool(self.__server_listeners) - self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) - self.__enabled_for_topology = bool(self.__topology_listeners) - self.__enabled_for_cmap = bool(self.__cmap_listeners) - - @property - def enabled_for_commands(self) -> bool: - """Are any CommandListener instances registered?""" - return self.__enabled_for_commands - - @property - def enabled_for_server(self) -> bool: - """Are any ServerListener instances registered?""" - return self.__enabled_for_server - - @property - def enabled_for_server_heartbeat(self) -> bool: - 
"""Are any ServerHeartbeatListener instances registered?""" - return self.__enabled_for_server_heartbeat - - @property - def enabled_for_topology(self) -> bool: - """Are any TopologyListener instances registered?""" - return self.__enabled_for_topology - - @property - def enabled_for_cmap(self) -> bool: - """Are any ConnectionPoolListener instances registered?""" - return self.__enabled_for_cmap - - def event_listeners(self) -> list[_EventListeners]: - """List of registered event listeners.""" - return ( - self.__command_listeners - + self.__server_heartbeat_listeners - + self.__server_listeners - + self.__topology_listeners - + self.__cmap_listeners - ) - - def publish_command_start( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - ) -> None: - """Publish a CommandStartedEvent to all command listeners. - - :param command: The command document. - :param database_name: The name of the database this command was run - against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. 
- """ - if op_id is None: - op_id = request_id - event = CommandStartedEvent( - command, - database_name, - request_id, - connection_id, - op_id, - service_id=service_id, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_command_success( - self, - duration: timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - speculative_hello: bool = False, - database_name: str = "", - ) -> None: - """Publish a CommandSucceededEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param speculative_hello: Was the command sent with speculative auth? - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - if speculative_hello: - # Redact entire response when the command started contained - # speculativeAuthenticate. 
- reply = {} - event = CommandSucceededEvent( - duration, - reply, - command_name, - request_id, - connection_id, - op_id, - service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_command_failure( - self, - duration: timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - database_name: str = "", - ) -> None: - """Publish a CommandFailedEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document or failure description - document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - event = CommandFailedEvent( - duration, - failure, - command_name, - request_id, - connection_id, - op_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: - """Publish a ServerHeartbeatStartedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param awaited: True if this heartbeat is part of an awaitable hello command. 
- """ - event = ServerHeartbeatStartedEvent(connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_succeeded( - self, connection_id: _Address, duration: float, reply: Hello, awaited: bool - ) -> None: - """Publish a ServerHeartbeatSucceededEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_failed( - self, connection_id: _Address, duration: float, reply: Exception, awaited: bool - ) -> None: - """Publish a ServerHeartbeatFailedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerOpeningEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. 
- """ - event = ServerOpeningEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerClosedEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerClosedEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_server_description_changed( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - server_address: _Address, - topology_id: ObjectId, - ) -> None: - """Publish a ServerDescriptionChangedEvent to all server listeners. - - :param previous_description: The previous server description. - :param server_address: The address (host, port) pair of the server. - :param new_description: The new server description. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerDescriptionChangedEvent( - previous_description, new_description, server_address, topology_id - ) - for subscriber in self.__server_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_topology_opened(self, topology_id: ObjectId) -> None: - """Publish a TopologyOpenedEvent to all topology listeners. - - :param topology_id: A unique identifier for the topology this server - is a part of. 
- """ - event = TopologyOpenedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_topology_closed(self, topology_id: ObjectId) -> None: - """Publish a TopologyClosedEvent to all topology listeners. - - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyClosedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_topology_description_changed( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - topology_id: ObjectId, - ) -> None: - """Publish a TopologyDescriptionChangedEvent to all topology listeners. - - :param previous_description: The previous topology description. - :param new_description: The new topology description. - :param topology_id: A unique identifier for the topology this server - is a part of. 
- """ - event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: - """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" - event = PoolCreatedEvent(address, options) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_created(event) - except Exception: - _handle_exception() - - def publish_pool_ready(self, address: _Address) -> None: - """Publish a :class:`PoolReadyEvent` to all pool listeners.""" - event = PoolReadyEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_ready(event) - except Exception: - _handle_exception() - - def publish_pool_cleared( - self, - address: _Address, - service_id: Optional[ObjectId], - interrupt_connections: bool = False, - ) -> None: - """Publish a :class:`PoolClearedEvent` to all pool listeners.""" - event = PoolClearedEvent(address, service_id, interrupt_connections) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_cleared(event) - except Exception: - _handle_exception() - - def publish_pool_closed(self, address: _Address) -> None: - """Publish a :class:`PoolClosedEvent` to all pool listeners.""" - event = PoolClosedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_closed(event) - except Exception: - _handle_exception() - - def publish_connection_created(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCreatedEvent` to all connection - listeners. 
- """ - event = ConnectionCreatedEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_created(event) - except Exception: - _handle_exception() - - def publish_connection_ready( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" - event = ConnectionReadyEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_ready(event) - except Exception: - _handle_exception() - - def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: - """Publish a :class:`ConnectionClosedEvent` to all connection - listeners. - """ - event = ConnectionClosedEvent(address, connection_id, reason) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_closed(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_started(self, address: _Address) -> None: - """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutStartedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_started(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_failed( - self, address: _Address, reason: str, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutFailedEvent(address, reason, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_failed(event) - except Exception: - _handle_exception() - - def publish_connection_checked_out( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckedOutEvent` to all connection - listeners. 
- """ - event = ConnectionCheckedOutEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_out(event) - except Exception: - _handle_exception() - - def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCheckedInEvent` to all connection - listeners. - """ - event = ConnectionCheckedInEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_in(event) - except Exception: - _handle_exception() diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 3f5319fd3..c1978087a 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -32,13 +32,18 @@ from typing import ( ) from bson import _decode_all_selective -from pymongo import _csot +from pymongo import _csot, helpers_shared, message +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( NotPrimaryError, OperationFailure, ProtocolError, _OperationCancelled, ) +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply +from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, @@ -47,24 +52,17 @@ from pymongo.network_layer import ( sendall, ) from pymongo.socket_checker import _errno_from_exception -from pymongo.synchronous import helpers as _async_helpers -from pymongo.synchronous import message as _async_message -from pymongo.synchronous.common import MAX_MESSAGE_SIZE -from pymongo.synchronous.compression_support import _NO_COMPRESSION, decompress -from pymongo.synchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.synchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply -from 
pymongo.synchronous.monitoring import _is_speculative_authenticate if TYPE_CHECKING: from bson import CodecOptions + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.monitoring import _EventListeners from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -129,7 +127,7 @@ def command( orig = spec if is_mongos and not use_op_msg: assert read_preference is not None - spec = _async_message._maybe_add_read_preference(spec, read_preference) + spec = message._maybe_add_read_preference(spec, read_preference) if read_concern and not (session and session.in_transaction): if read_concern.level: spec["readConcern"] = read_concern.document @@ -157,22 +155,20 @@ def command( if use_op_msg: flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = _async_message._op_msg( + request_id, msg, size, max_doc_size = message._op_msg( flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx ) # If this is an unacknowledged write then make sure the encoded doc(s) # are small enough, otherwise rely on the server to return an error. 
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - _async_message._raise_document_too_large(name, size, max_bson_size) + message._raise_document_too_large(name, size, max_bson_size) else: - request_id, msg, size = _async_message._query( + request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) - if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD: - _async_message._raise_document_too_large( - name, size, max_bson_size + _async_message._COMMAND_OVERHEAD - ) + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -219,7 +215,7 @@ def command( if client: client._process_response(response_doc, session) if check: - _async_helpers._check_command_response( + helpers_shared._check_command_response( response_doc, conn.max_wire_version, allowable_errors, @@ -230,7 +226,7 @@ def command( if isinstance(exc, (NotPrimaryError, OperationFailure)): failure: _DocumentOut = exc.details # type: ignore[assignment] else: - failure = _async_message._convert_exception(exc) + failure = message._convert_exception(exc) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( diff --git a/pymongo/synchronous/operations.py b/pymongo/synchronous/operations.py deleted file mode 100644 index 148f84a42..000000000 --- a/pymongo/synchronous/operations.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Operation class definitions.""" -from __future__ import annotations - -import enum -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) - -from bson.raw_bson import RawBSONDocument -from pymongo.synchronous import helpers -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.common import validate_is_mapping, validate_list -from pymongo.synchronous.helpers import _gen_index_name, _index_document, _index_list -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline -from pymongo.write_concern import validate_boolean - -if TYPE_CHECKING: - from pymongo.synchronous.bulk import _Bulk - -_IS_SYNC = True - -# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary -_IndexList = Union[ - Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] -] -_IndexKeyHint = Union[str, _IndexList] - - -class _Op(str, enum.Enum): - ABORT = "abortTransaction" - AGGREGATE = "aggregate" - COMMIT = "commitTransaction" - COUNT = "count" - CREATE = "create" - CREATE_INDEXES = "createIndexes" - CREATE_SEARCH_INDEXES = "createSearchIndexes" - DELETE = "delete" - DISTINCT = "distinct" - DROP = "drop" - DROP_DATABASE = "dropDatabase" - DROP_INDEXES = "dropIndexes" - DROP_SEARCH_INDEXES = "dropSearchIndexes" - END_SESSIONS = "endSessions" - FIND_AND_MODIFY = "findAndModify" - FIND = "find" - INSERT = "insert" - LIST_COLLECTIONS = "listCollections" - LIST_INDEXES 
= "listIndexes" - LIST_SEARCH_INDEX = "listSearchIndexes" - LIST_DATABASES = "listDatabases" - UPDATE = "update" - UPDATE_INDEX = "updateIndex" - UPDATE_SEARCH_INDEX = "updateSearchIndex" - RENAME = "rename" - GETMORE = "getMore" - KILL_CURSORS = "killCursors" - TEST = "testOperation" - - -class InsertOne(Generic[_DocumentType]): - """Represents an insert_one operation.""" - - __slots__ = ("_doc",) - - def __init__(self, document: _DocumentType) -> None: - """Create an InsertOne instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param document: The document to insert. If the document is missing an - _id field one will be added. - """ - self._doc = document - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_insert(self._doc) # type: ignore[arg-type] - - def __repr__(self) -> str: - return f"InsertOne({self._doc!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return other._doc == self._doc - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteOne: - """Represents a delete_one operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteOne instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the document to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. 
versionchanged:: 3.5 - Added the `collation` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 1, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteMany: - """Represents a delete_many operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteMany instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the documents to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. 
- """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 0, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class ReplaceOne(Generic[_DocumentType]): - """Represents a replace_one operation.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - replacement: Union[_DocumentType, RawBSONDocument], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a ReplaceOne instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the document to replace. - :param replacement: The new document. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. 
versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the ``collation`` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._doc = replacement - self._upsert = upsert - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_replace( - self._filter, - self._doc, - self._upsert, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - other._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._hint, - ) - - -class _UpdateOp: - """Private base class for update operations.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - doc: Union[Mapping[str, Any], _Pipeline], - upsert: bool, - collation: Optional[_CollationIn], - array_filters: Optional[list[Mapping[str, Any]]], - hint: Optional[_IndexKeyHint], - ): - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if array_filters is not None: - validate_list("array_filters", array_filters) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, 
dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - - self._filter = filter - self._doc = doc - self._upsert = upsert - self._collation = collation - self._array_filters = array_filters - - def __eq__(self, other: object) -> bool: - if isinstance(other, type(self)): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._array_filters, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - return NotImplemented - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - - -class UpdateOne(_UpdateOp): - """Represents an update_one operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Represents an update_one operation. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the document to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. 
versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - False, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class UpdateMany(_UpdateOp): - """Represents an update_many operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create an UpdateMany instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the documents to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. 
- """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - True, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class IndexModel: - """Represents an index to create.""" - - __slots__ = ("__document",) - - def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: - """Create an Index instance. - - For use with :meth:`~pymongo.collection.Collection.create_indexes`. - - Takes either a single key or a list containing (key, direction) pairs - or keys. If no direction is given, :data:`~pymongo.ASCENDING` will - be assumed. - The key(s) must be an instance of :class:`str`, and the direction(s) must - be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, - :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, - :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). - - Valid options include, but are not limited to: - - - `name`: custom name to use for this index - if none is - given, a name will be generated. - - `unique`: if ``True``, creates a uniqueness constraint on the index. - - `background`: if ``True``, this index should be created in the - background. - - `sparse`: if ``True``, omit from the index any documents that lack - the indexed field. - - `bucketSize`: for use with geoHaystack indexes. - Number of documents to group together within a certain proximity - to a given longitude and latitude. - - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` - index. - - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` - index. - - `expireAfterSeconds`: Used to create an expiring (TTL) - collection. MongoDB will automatically delete documents from - this collection after seconds. The indexed field must - be a UTC datetime or the data will not expire. 
- - `partialFilterExpression`: A document that specifies a filter for - a partial index. - - `collation`: An instance of :class:`~pymongo.collation.Collation` - that specifies the collation to use. - - `wildcardProjection`: Allows users to include or exclude specific - field paths from a `wildcard index`_ using the { "$**" : 1} key - pattern. Requires MongoDB >= 4.2. - - `hidden`: if ``True``, this index will be hidden from the query - planner and will not be evaluated as part of query plan - selection. Requires MongoDB >= 4.4. - - See the MongoDB documentation for a full list of supported options by - server version. - - :param keys: a single key or a list containing (key, direction) pairs - or keys specifying the index to create. - :param kwargs: any additional index creation - options (see the above list) should be passed as keyword - arguments. - - .. versionchanged:: 3.11 - Added the ``hidden`` option. - .. versionchanged:: 3.2 - Added the ``partialFilterExpression`` option to support partial - indexes. - - .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/ - """ - keys = _index_list(keys) - if kwargs.get("name") is None: - kwargs["name"] = _gen_index_name(keys) - kwargs["key"] = _index_document(keys) - collation = validate_collation_or_none(kwargs.pop("collation", None)) - self.__document = kwargs - if collation is not None: - self.__document["collation"] = collation - - @property - def document(self) -> dict[str, Any]: - """An index document suitable for passing to the createIndexes - command. - """ - return self.__document - - -class SearchIndexModel: - """Represents a search index to create.""" - - __slots__ = ("__document",) - - def __init__( - self, - definition: Mapping[str, Any], - name: Optional[str] = None, - type: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Create a Search Index instance. 
- - For use with :meth:`~pymongo.collection.Collection.create_search_index` and :meth:`~pymongo.collection.Collection.create_search_indexes`. - - :param definition: The definition for this index. - :param name: The name for this index, if present. - :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". - :param kwargs: Keyword arguments supplying any additional options. - - .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. - .. versionadded:: 4.5 - .. versionchanged:: 4.7 - Added the type and kwargs arguments. - """ - self.__document: dict[str, Any] = {} - if name is not None: - self.__document["name"] = name - self.__document["definition"] = definition - if type is not None: - self.__document["type"] = type - self.__document.update(kwargs) - - @property - def document(self) -> Mapping[str, Any]: - """The document for this index.""" - return self.__document diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index f77f78cd7..88de474c2 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -16,17 +16,14 @@ from __future__ import annotations import collections import contextlib -import copy import logging import os -import platform import socket import ssl import sys import threading import time import weakref -from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -39,9 +36,15 @@ from typing import ( Union, ) -import bson from bson import DEFAULT_CODEC_OPTIONS -from pymongo import __version__, _csot +from pymongo import _csot, helpers_shared +from pymongo.common import ( + MAX_BSON_SIZE, + MAX_MESSAGE_SIZE, + MAX_WIRE_VERSION, + MAX_WRITE_BATCH_SIZE, + ORDERED_TYPES, +) from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, @@ -56,62 +59,45 @@ from pymongo.errors import ( # type:ignore[attr-defined] WaitQueueTimeoutError, _CertificateError, ) +from pymongo.hello import Hello, HelloCompat from 
pymongo.lock import _create_lock -from pymongo.network_layer import sendall -from pymongo.server_api import _add_to_command -from pymongo.server_type import SERVER_TYPE -from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import HAS_SNI, SSLError -from pymongo.synchronous import helpers -from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.common import ( - MAX_BSON_SIZE, - MAX_CONNECTING, - MAX_IDLE_TIME_SEC, - MAX_MESSAGE_SIZE, - MAX_POOL_SIZE, - MAX_WIRE_VERSION, - MAX_WRITE_BATCH_SIZE, - MIN_POOL_SIZE, - ORDERED_TYPES, - WAIT_QUEUE_TIMEOUT, -) -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.logger import ( +from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, _debug_log, _verbose_connection_error_reason, ) -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( ConnectionCheckOutFailedReason, ConnectionClosedReason, - _EventListeners, ) +from pymongo.network_layer import sendall +from pymongo.pool_options import PoolOptions +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import _add_to_command +from pymongo.server_type import SERVER_TYPE +from pymongo.socket_checker import SocketChecker +from pymongo.ssl_support import HAS_SNI, SSLError +from pymongo.synchronous.client_session import _validate_session_write_concern +from pymongo.synchronous.helpers import _handle_reauth from pymongo.synchronous.network import command, receive_message -from pymongo.synchronous.read_preferences import ReadPreference if TYPE_CHECKING: from bson import CodecOptions from bson.objectid import ObjectId - from pymongo.driver_info import DriverInfo - from pymongo.pyopenssl_context import SSLContext, _sslConn - from pymongo.read_concern import ReadConcern - from pymongo.server_api import ServerApi - from 
pymongo.synchronous.auth import MongoCredential, _AuthContext - from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.compression_support import ( - CompressionSettings, + from pymongo.compression_support import ( SnappyContext, ZlibContext, ZstdContext, ) - from pymongo.synchronous.message import _OpMsg, _OpReply + from pymongo.message import _OpMsg, _OpReply + from pymongo.pyopenssl_context import _sslConn + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.synchronous.auth import _AuthContext + from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.typings import ClusterTime, _Address, _CollationIn + from pymongo.typings import ClusterTime, _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -191,236 +177,6 @@ else: _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) -_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} - -if sys.platform.startswith("linux"): - _METADATA["os"] = { - "type": platform.system(), - "name": platform.system(), - "architecture": platform.machine(), - # Kernel version (e.g. 4.4.0-17-generic). - "version": platform.release(), - } -elif sys.platform == "darwin": - _METADATA["os"] = { - "type": platform.system(), - "name": platform.system(), - "architecture": platform.machine(), - # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin - # kernel version. - "version": platform.mac_ver()[0], - } -elif sys.platform == "win32": - _ver = sys.getwindowsversion() - _METADATA["os"] = { - "type": "Windows", - "name": "Windows", - # Avoid using platform calls, see PYTHON-4455. - "architecture": os.environ.get("PROCESSOR_ARCHITECTURE") or platform.machine(), - # Windows patch level (e.g. 10.0.17763-SP0). 
- "version": ".".join(map(str, _ver[:3])) + f"-SP{_ver[-1] or '0'}", - } -elif sys.platform.startswith("java"): - _name, _ver, _arch = platform.java_ver()[-1] - _METADATA["os"] = { - # Linux, Windows 7, Mac OS X, etc. - "type": _name, - "name": _name, - # x86, x86_64, AMD64, etc. - "architecture": _arch, - # Linux kernel version, OSX version, etc. - "version": _ver, - } -else: - # Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11) - _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) - _METADATA["os"] = { - "type": platform.system(), - "name": " ".join([part for part in _aliased[:2] if part]), - "architecture": platform.machine(), - "version": _aliased[2], - } - -if platform.python_implementation().startswith("PyPy"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.pypy_version_info)), # type: ignore - "(Python %s)" % ".".join(map(str, sys.version_info)), - ) - ) -elif sys.platform.startswith("java"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.version_info)), - "(%s)" % " ".join((platform.system(), platform.release())), - ) - ) -else: - _METADATA["platform"] = " ".join( - (platform.python_implementation(), ".".join(map(str, sys.version_info))) - ) - -DOCKER_ENV_PATH = "/.dockerenv" -ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" - -RUNTIME_NAME_DOCKER = "docker" -ORCHESTRATOR_NAME_K8S = "kubernetes" - - -def get_container_env_info() -> dict[str, str]: - """Returns the runtime and orchestrator of a container. 
- If neither value is present, the metadata client.env.container field will be omitted.""" - container = {} - - if Path(DOCKER_ENV_PATH).exists(): - container["runtime"] = RUNTIME_NAME_DOCKER - if os.getenv(ENV_VAR_K8S): - container["orchestrator"] = ORCHESTRATOR_NAME_K8S - - return container - - -def _is_lambda() -> bool: - if os.getenv("AWS_LAMBDA_RUNTIME_API"): - return True - env = os.getenv("AWS_EXECUTION_ENV") - if env: - return env.startswith("AWS_Lambda_") - return False - - -def _is_azure_func() -> bool: - return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) - - -def _is_gcp_func() -> bool: - return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) - - -def _is_vercel() -> bool: - return bool(os.getenv("VERCEL")) - - -def _is_faas() -> bool: - return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() - - -def _getenv_int(key: str) -> Optional[int]: - """Like os.getenv but returns an int, or None if the value is missing/malformed.""" - val = os.getenv(key) - if not val: - return None - try: - return int(val) - except ValueError: - return None - - -def _metadata_env() -> dict[str, Any]: - env: dict[str, Any] = {} - container = get_container_env_info() - if container: - env["container"] = container - # Skip if multiple (or no) envs are matched. 
- if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: - return env - if _is_lambda(): - env["name"] = "aws.lambda" - region = os.getenv("AWS_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") - if memory_mb is not None: - env["memory_mb"] = memory_mb - elif _is_azure_func(): - env["name"] = "azure.func" - elif _is_gcp_func(): - env["name"] = "gcp.func" - region = os.getenv("FUNCTION_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("FUNCTION_MEMORY_MB") - if memory_mb is not None: - env["memory_mb"] = memory_mb - timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") - if timeout_sec is not None: - env["timeout_sec"] = timeout_sec - elif _is_vercel(): - env["name"] = "vercel" - region = os.getenv("VERCEL_REGION") - if region: - env["region"] = region - return env - - -_MAX_METADATA_SIZE = 512 - - -# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations -def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: - """Perform metadata truncation.""" - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 1. Omit fields from env except env.name. - env_name = metadata.get("env", {}).get("name") - if env_name: - metadata["env"] = {"name": env_name} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 2. Omit fields from os except os.type. - os_type = metadata.get("os", {}).get("type") - if os_type: - metadata["os"] = {"type": os_type} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 3. Omit the env document entirely. - metadata.pop("env", None) - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # 4. Truncate platform. 
- overflow = encoded_size - _MAX_METADATA_SIZE - plat = metadata.get("platform", "") - if plat: - plat = plat[:-overflow] - if plat: - metadata["platform"] = plat - else: - metadata.pop("platform", None) - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # 5. Truncate driver info. - overflow = encoded_size - _MAX_METADATA_SIZE - driver = metadata.get("driver", {}) - if driver: - # Truncate driver version. - driver_version = driver.get("version")[:-overflow] - if len(driver_version) >= len(_METADATA["driver"]["version"]): - metadata["driver"]["version"] = driver_version - else: - metadata["driver"]["version"] = _METADATA["driver"]["version"] - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # Truncate driver name. - overflow = encoded_size - _MAX_METADATA_SIZE - driver_name = driver.get("name")[:-overflow] - if len(driver_name) >= len(_METADATA["driver"]["name"]): - metadata["driver"]["name"] = driver_name - else: - metadata["driver"]["name"] = _METADATA["driver"]["name"] - - -# If the first getaddrinfo call of this interpreter's life is on a thread, -# while the main thread holds the import lock, getaddrinfo deadlocks trying -# to import the IDNA codec. Import it here, where presumably we're on the -# main thread, to avoid the deadlock. See PYTHON-607. -"foo".encode("idna") - - def _raise_connection_failure( address: Any, error: Exception, @@ -481,238 +237,6 @@ def format_timeout_details(details: Optional[dict[str, float]]) -> str: return result -class PoolOptions: - """Read only connection pool options for a MongoClient. - - Should not be instantiated directly by application developers. 
Access - a client's pool options via - :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: - - pool_opts = client.options.pool_options - pool_opts.max_pool_size - pool_opts.min_pool_size - - """ - - __slots__ = ( - "__max_pool_size", - "__min_pool_size", - "__max_idle_time_seconds", - "__connect_timeout", - "__socket_timeout", - "__wait_queue_timeout", - "__ssl_context", - "__tls_allow_invalid_hostnames", - "__event_listeners", - "__appname", - "__driver", - "__metadata", - "__compression_settings", - "__max_connecting", - "__pause_enabled", - "__server_api", - "__load_balanced", - "__credentials", - ) - - def __init__( - self, - max_pool_size: int = MAX_POOL_SIZE, - min_pool_size: int = MIN_POOL_SIZE, - max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, - connect_timeout: Optional[float] = None, - socket_timeout: Optional[float] = None, - wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, - ssl_context: Optional[SSLContext] = None, - tls_allow_invalid_hostnames: bool = False, - event_listeners: Optional[_EventListeners] = None, - appname: Optional[str] = None, - driver: Optional[DriverInfo] = None, - compression_settings: Optional[CompressionSettings] = None, - max_connecting: int = MAX_CONNECTING, - pause_enabled: bool = True, - server_api: Optional[ServerApi] = None, - load_balanced: Optional[bool] = None, - credentials: Optional[MongoCredential] = None, - ): - self.__max_pool_size = max_pool_size - self.__min_pool_size = min_pool_size - self.__max_idle_time_seconds = max_idle_time_seconds - self.__connect_timeout = connect_timeout - self.__socket_timeout = socket_timeout - self.__wait_queue_timeout = wait_queue_timeout - self.__ssl_context = ssl_context - self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames - self.__event_listeners = event_listeners - self.__appname = appname - self.__driver = driver - self.__compression_settings = compression_settings - self.__max_connecting = max_connecting - self.__pause_enabled = 
pause_enabled - self.__server_api = server_api - self.__load_balanced = load_balanced - self.__credentials = credentials - self.__metadata = copy.deepcopy(_METADATA) - if appname: - self.__metadata["application"] = {"name": appname} - - # Combine the "driver" MongoClient option with PyMongo's info, like: - # { - # 'driver': { - # 'name': 'PyMongo|MyDriver', - # 'version': '4.2.0|1.2.3', - # }, - # 'platform': 'CPython 3.8.0|MyPlatform' - # } - if driver: - if driver.name: - self.__metadata["driver"]["name"] = "{}|{}".format( - _METADATA["driver"]["name"], - driver.name, - ) - if driver.version: - self.__metadata["driver"]["version"] = "{}|{}".format( - _METADATA["driver"]["version"], - driver.version, - ) - if driver.platform: - self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) - - env = _metadata_env() - if env: - self.__metadata["env"] = env - - _truncate_metadata(self.__metadata) - - @property - def _credentials(self) -> Optional[MongoCredential]: - """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" - return self.__credentials - - @property - def non_default_options(self) -> dict[str, Any]: - """The non-default options this pool was created with. - - Added for CMAP's :class:`PoolCreatedEvent`. 
- """ - opts = {} - if self.__max_pool_size != MAX_POOL_SIZE: - opts["maxPoolSize"] = self.__max_pool_size - if self.__min_pool_size != MIN_POOL_SIZE: - opts["minPoolSize"] = self.__min_pool_size - if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: - assert self.__max_idle_time_seconds is not None - opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 - if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: - assert self.__wait_queue_timeout is not None - opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 - if self.__max_connecting != MAX_CONNECTING: - opts["maxConnecting"] = self.__max_connecting - return opts - - @property - def max_pool_size(self) -> float: - """The maximum allowable number of concurrent connections to each - connected server. Requests to a server will block if there are - `maxPoolSize` outstanding connections to the requested server. - Defaults to 100. Cannot be 0. - - When a server's pool has reached `max_pool_size`, operations for that - server block waiting for a socket to be returned to the pool. If - ``waitQueueTimeoutMS`` is set, a blocked operation will raise - :exc:`~pymongo.errors.ConnectionFailure` after a timeout. - By default ``waitQueueTimeoutMS`` is not set. - """ - return self.__max_pool_size - - @property - def min_pool_size(self) -> int: - """The minimum required number of concurrent connections that the pool - will maintain to each connected server. Default is 0. - """ - return self.__min_pool_size - - @property - def max_connecting(self) -> int: - """The maximum number of concurrent connection creation attempts per - pool. Defaults to 2. - """ - return self.__max_connecting - - @property - def pause_enabled(self) -> bool: - return self.__pause_enabled - - @property - def max_idle_time_seconds(self) -> Optional[int]: - """The maximum number of seconds that a connection can remain - idle in the pool before being removed and replaced. Defaults to - `None` (no limit). 
- """ - return self.__max_idle_time_seconds - - @property - def connect_timeout(self) -> Optional[float]: - """How long a connection can take to be opened before timing out.""" - return self.__connect_timeout - - @property - def socket_timeout(self) -> Optional[float]: - """How long a send or receive on a socket can take before timing out.""" - return self.__socket_timeout - - @property - def wait_queue_timeout(self) -> Optional[int]: - """How long a thread will wait for a socket from the pool if the pool - has no free sockets. - """ - return self.__wait_queue_timeout - - @property - def _ssl_context(self) -> Optional[SSLContext]: - """An SSLContext instance or None.""" - return self.__ssl_context - - @property - def tls_allow_invalid_hostnames(self) -> bool: - """If True skip ssl.match_hostname.""" - return self.__tls_allow_invalid_hostnames - - @property - def _event_listeners(self) -> Optional[_EventListeners]: - """An instance of pymongo.monitoring._EventListeners.""" - return self.__event_listeners - - @property - def appname(self) -> Optional[str]: - """The application name, for sending with hello in server handshake.""" - return self.__appname - - @property - def driver(self) -> Optional[DriverInfo]: - """Driver name and version, for sending with hello in handshake.""" - return self.__driver - - @property - def _compression_settings(self) -> Optional[CompressionSettings]: - return self.__compression_settings - - @property - def metadata(self) -> dict[str, Any]: - """A dict of metadata about the application, driver, os, and platform.""" - return self.__metadata.copy() - - @property - def server_api(self) -> Optional[ServerApi]: - """A pymongo.server_api.ServerApi or None.""" - return self.__server_api - - @property - def load_balanced(self) -> Optional[bool]: - """True if this Pool is configured in load balanced mode.""" - return self.__load_balanced - - class _CancellationContext: def __init__(self) -> None: self._cancelled = False @@ -946,7 +470,7 @@ class 
Connection: self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() response_doc = unpacked_docs[0] - helpers._check_command_response(response_doc, self.max_wire_version) + helpers_shared._check_command_response(response_doc, self.max_wire_version) return response_doc @_handle_reauth @@ -1099,7 +623,7 @@ class Connection: result = reply.command_response(codec_options) # Raises NotPrimaryError or OperationFailure. - helpers._check_command_response(result, self.max_wire_version) + helpers_shared._check_command_response(result, self.max_wire_version) return result def authenticate(self, reauthenticate: bool = False) -> None: diff --git a/pymongo/synchronous/read_preferences.py b/pymongo/synchronous/read_preferences.py deleted file mode 100644 index 464256c34..000000000 --- a/pymongo/synchronous/read_preferences.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2012-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License", -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Utilities for choosing which member of a replica set to read from.""" - -from __future__ import annotations - -from collections import abc -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from pymongo.errors import ConfigurationError -from pymongo.synchronous import max_staleness_selectors -from pymongo.synchronous.server_selectors import ( - member_with_tags_server_selector, - secondary_with_tags_server_selector, -) - -if TYPE_CHECKING: - from pymongo.synchronous.server_selectors import Selection - from pymongo.synchronous.topology_description import TopologyDescription - -_IS_SYNC = True - -_PRIMARY = 0 -_PRIMARY_PREFERRED = 1 -_SECONDARY = 2 -_SECONDARY_PREFERRED = 3 -_NEAREST = 4 - - -_MONGOS_MODES = ( - "primary", - "primaryPreferred", - "secondary", - "secondaryPreferred", - "nearest", -) - -_Hedge = Mapping[str, Any] -_TagSets = Sequence[Mapping[str, Any]] - - -def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: - """Validate tag sets for a MongoClient.""" - if tag_sets is None: - return tag_sets - - if not isinstance(tag_sets, (list, tuple)): - raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") - if len(tag_sets) == 0: - raise ValueError( - f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags" - ) - - for tags in tag_sets: - if not isinstance(tags, abc.Mapping): - raise TypeError( - f"Tag set {tags!r} invalid, must be an instance of dict, " - "bson.son.SON or other type that inherits from " - "collection.Mapping" - ) - - return list(tag_sets) - - -def _invalid_max_staleness_msg(max_staleness: Any) -> str: - return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness - - -# Some duplication with common.py to avoid import cycle. 
-def _validate_max_staleness(max_staleness: Any) -> int: - """Validate max_staleness.""" - if max_staleness == -1: - return -1 - - if not isinstance(max_staleness, int): - raise TypeError(_invalid_max_staleness_msg(max_staleness)) - - if max_staleness <= 0: - raise ValueError(_invalid_max_staleness_msg(max_staleness)) - - return max_staleness - - -def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: - """Validate hedge.""" - if hedge is None: - return None - - if not isinstance(hedge, dict): - raise TypeError(f"hedge must be a dictionary, not {hedge!r}") - - return hedge - - -class _ServerMode: - """Base class for all read preferences.""" - - __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") - - def __init__( - self, - mode: int, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - self.__mongos_mode = _MONGOS_MODES[mode] - self.__mode = mode - self.__tag_sets = _validate_tag_sets(tag_sets) - self.__max_staleness = _validate_max_staleness(max_staleness) - self.__hedge = _validate_hedge(hedge) - - @property - def name(self) -> str: - """The name of this read preference.""" - return self.__class__.__name__ - - @property - def mongos_mode(self) -> str: - """The mongos mode of this read preference.""" - return self.__mongos_mode - - @property - def document(self) -> dict[str, Any]: - """Read preference as a document.""" - doc: dict[str, Any] = {"mode": self.__mongos_mode} - if self.__tag_sets not in (None, [{}]): - doc["tags"] = self.__tag_sets - if self.__max_staleness != -1: - doc["maxStalenessSeconds"] = self.__max_staleness - if self.__hedge not in (None, {}): - doc["hedge"] = self.__hedge - return doc - - @property - def mode(self) -> int: - """The mode of this read preference instance.""" - return self.__mode - - @property - def tag_sets(self) -> _TagSets: - """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to - read only from members whose 
``dc`` tag has the value ``"ny"``. - To specify a priority-order for tag sets, provide a list of - tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag - set, ``{}``, means "read from any member that matches the mode, - ignoring tags." MongoClient tries each set of tags in turn - until it finds a set of tags with at least one matching member. - For example, to only send a query to an analytic node:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - Or using :class:`SecondaryPreferred`:: - - SecondaryPreferred(tag_sets=[{"node":"analytics"}]) - - .. seealso:: `Data-Center Awareness - `_ - """ - return list(self.__tag_sets) if self.__tag_sets else [{}] - - @property - def max_staleness(self) -> int: - """The maximum estimated length of time (in seconds) a replica set - secondary can fall behind the primary in replication before it will - no longer be selected for operations, or -1 for no maximum. - """ - return self.__max_staleness - - @property - def hedge(self) -> Optional[_Hedge]: - """The read preference ``hedge`` parameter. - - A dictionary that configures how the server will perform hedged reads. - It consists of the following keys: - - - ``enabled``: Enables or disables hedged reads in sharded clusters. - - Hedged reads are automatically enabled in MongoDB 4.4+ when using a - ``nearest`` read preference. To explicitly enable hedged reads, set - the ``enabled`` key to ``true``:: - - >>> Nearest(hedge={'enabled': True}) - - To explicitly disable hedged reads, set the ``enabled`` key to - ``False``:: - - >>> Nearest(hedge={'enabled': False}) - - .. versionadded:: 3.11 - """ - return self.__hedge - - @property - def min_wire_version(self) -> int: - """The wire protocol version the server must support. - - Some read preferences impose version requirements on all servers (e.g. - maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5). 
- - All servers' maxWireVersion must be at least this read preference's - `min_wire_version`, or the driver raises - :exc:`~pymongo.errors.ConfigurationError`. - """ - return 0 if self.__max_staleness == -1 else 5 - - def __repr__(self) -> str: - return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format( - self.name, - self.__tag_sets, - self.__max_staleness, - self.__hedge, - ) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return ( - self.mode == other.mode - and self.tag_sets == other.tag_sets - and self.max_staleness == other.max_staleness - and self.hedge == other.hedge - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __getstate__(self) -> dict[str, Any]: - """Return value of object for pickling. - - Needed explicitly because __slots__() defined. - """ - return { - "mode": self.__mode, - "tag_sets": self.__tag_sets, - "max_staleness": self.__max_staleness, - "hedge": self.__hedge, - } - - def __setstate__(self, value: Mapping[str, Any]) -> None: - """Restore from pickling.""" - self.__mode = value["mode"] - self.__mongos_mode = _MONGOS_MODES[self.__mode] - self.__tag_sets = _validate_tag_sets(value["tag_sets"]) - self.__max_staleness = _validate_max_staleness(value["max_staleness"]) - self.__hedge = _validate_hedge(value["hedge"]) - - def __call__(self, selection: Selection) -> Selection: - return selection - - -class Primary(_ServerMode): - """Primary read preference. - - * When directly connected to one mongod queries are allowed if the server - is standalone or a replica set primary. - * When connected to a mongos queries are sent to the primary of a shard. - * When connected to a replica set queries are sent to the primary of - the replica set. 
- """ - - __slots__ = () - - def __init__(self) -> None: - super().__init__(_PRIMARY) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return selection.primary_selection - - def __repr__(self) -> str: - return "Primary()" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return other.mode == _PRIMARY - return NotImplemented - - -class PrimaryPreferred(_ServerMode): - """PrimaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are sent to the primary of a shard if - available, otherwise a shard secondary. - * When connected to a replica set queries are sent to the primary if - available, otherwise a secondary. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to an available secondary until the - primary of the replica set is discovered. - - :param tag_sets: The :attr:`~tag_sets` to use if the primary is not - available. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` to use if the primary is not available. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. 
- """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - if selection.primary: - return selection.primary_selection - else: - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class Secondary(_ServerMode): - """Secondary read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries. An error is raised if no secondaries are available. - * When connected to a replica set queries are distributed among - secondaries. An error is raised if no secondaries are available. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. 
- """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class SecondaryPreferred(_ServerMode): - """SecondaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries, or the shard primary if no secondary is available. - * When connected to a replica set queries are distributed among - secondaries, or the primary if no secondary is available. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to the primary of the replica set until - an available secondary is discovered. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. 
- """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - secondaries = secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - if secondaries: - return secondaries - else: - return selection.primary_selection - - -class Nearest(_ServerMode): - """Nearest read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among all members of - a shard. - * When connected to a replica set queries are distributed among all - members. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_NEAREST, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return member_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class _AggWritePref: - """Agg $out/$merge write preference. 
- - * If there are readable servers and there is any pre-5.0 server, use - primary read preference. - * Otherwise use `pref` read preference. - - :param pref: The read preference to use on MongoDB 5.0+. - """ - - __slots__ = ("pref", "effective_pref") - - def __init__(self, pref: _ServerMode): - self.pref = pref - self.effective_pref: _ServerMode = ReadPreference.PRIMARY - - def selection_hook(self, topology_description: TopologyDescription) -> None: - common_wv = topology_description.common_wire_version - if ( - topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) - and common_wv - and common_wv < 13 - ): - self.effective_pref = ReadPreference.PRIMARY - else: - self.effective_pref = self.pref - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return self.effective_pref(selection) - - def __repr__(self) -> str: - return f"_AggWritePref(pref={self.pref!r})" - - # Proxy other calls to the effective_pref so that _AggWritePref can be - # used in place of an actual read preference. - def __getattr__(self, name: str) -> Any: - return getattr(self.effective_pref, name) - - -_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) - - -def make_read_preference( - mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 -) -> _ServerMode: - if mode == _PRIMARY: - if tag_sets not in (None, [{}]): - raise ConfigurationError("Read preference primary cannot be combined with tags") - if max_staleness != -1: - raise ConfigurationError( - "Read preference primary cannot be combined with maxStalenessSeconds" - ) - return Primary() - return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore - - -_MODES = ( - "PRIMARY", - "PRIMARY_PREFERRED", - "SECONDARY", - "SECONDARY_PREFERRED", - "NEAREST", -) - - -class ReadPreference: - """An enum that defines some commonly used read preference modes. 
- - Apps can also create a custom read preference, for example:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - See :doc:`/examples/high_availability` for code examples. - - A read preference is used in three cases: - - :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: - - - ``PRIMARY``: Queries are allowed if the server is standalone or a replica - set primary. - - All other modes allow queries to standalone servers, to a replica set - primary, or to replica set secondaries. - - :class:`~pymongo.mongo_client.MongoClient` initialized with the - ``replicaSet`` option: - - - ``PRIMARY``: Read from the primary. This is the default, and provides the - strongest consistency. If no primary is available, raise - :class:`~pymongo.errors.AutoReconnect`. - - - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is - none, read from a secondary. - - - ``SECONDARY``: Read from a secondary. If no secondary is available, - raise :class:`~pymongo.errors.AutoReconnect`. - - - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise - from the primary. - - - ``NEAREST``: Read from any member. - - :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a - sharded cluster of replica sets: - - - ``PRIMARY``: Read from the primary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - This is the default. - - - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is - none, read from a secondary of the shard. - - - ``SECONDARY``: Read from a secondary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - - - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, - otherwise from the shard primary. - - - ``NEAREST``: Read from any shard member. 
- """ - - PRIMARY = Primary() - PRIMARY_PREFERRED = PrimaryPreferred() - SECONDARY = Secondary() - SECONDARY_PREFERRED = SecondaryPreferred() - NEAREST = Nearest() - - -def read_pref_mode_from_name(name: str) -> int: - """Get the read preference mode from mongos/uri name.""" - return _MONGOS_MODES.index(name) - - -class MovingAverage: - """Tracks an exponentially-weighted moving average.""" - - average: Optional[float] - - def __init__(self) -> None: - self.average = None - - def add_sample(self, sample: float) -> None: - if sample < 0: - # Likely system time change while waiting for hello response - # and not using time.monotonic. Ignore it, the next one will - # probably be valid. - return - if self.average is None: - self.average = sample - else: - # The Server Selection Spec requires an exponentially weighted - # average with alpha = 0.2. - self.average = 0.8 * self.average + 0.2 * sample - - def get(self) -> Optional[float]: - """Get the calculated average, or None if no samples yet.""" - return self.average - - def reset(self) -> None: - self.average = None diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 4c7956999..347155784 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -28,23 +28,24 @@ from typing import ( from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.synchronous.helpers import _check_command_response, _handle_reauth -from pymongo.synchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.synchronous.message import _convert_exception, _GetMore, _OpMsg, _Query -from pymongo.synchronous.response import PinnedResponse, Response +from pymongo.helpers_shared import _check_command_response +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.response import PinnedResponse, Response 
+from pymongo.synchronous.helpers import _handle_reauth if TYPE_CHECKING: from queue import Queue from weakref import ReferenceType from bson.objectid import ObjectId + from pymongo.monitoring import _EventListeners + from pymongo.read_preferences import _ServerMode + from pymongo.server_description import ServerDescription from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler from pymongo.synchronous.monitor import Monitor - from pymongo.synchronous.monitoring import _EventListeners from pymongo.synchronous.pool import Connection, Pool - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.server_description import ServerDescription - from pymongo.synchronous.typings import _DocumentOut + from pymongo.typings import _DocumentOut _IS_SYNC = True @@ -105,6 +106,19 @@ class Server: """Check the server's state soon.""" self._monitor.request_check() + def operation_to_command( + self, operation: Union[_Query, _GetMore], conn: Connection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + cmd, db = operation.as_command(conn, apply_timeout) + # Support auto encryption + if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption: + cmd = operation.client._encrypter.encrypt( # type: ignore[misc, assignment] + operation.db, cmd, operation.codec_options + ) + operation.update_command(cmd) + + return cmd, db + @_handle_reauth def run_operation( self, @@ -126,21 +140,21 @@ class Server: :param read_preference: The read preference to use. :param listeners: Instance of _EventListeners or None. :param unpack_res: A callable that decodes the wire protocol response. + :param client: A MongoClient instance. 
""" - duration = None assert listeners is not None publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come + cmd, dbn = self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - cmd, dbn = operation.as_command(conn) if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, @@ -159,7 +173,6 @@ class Server: ) if publish: - cmd, dbn = operation.as_command(conn) if "$db" not in cmd: cmd["$db"] = dbn assert listeners is not None @@ -195,7 +208,7 @@ class Server: ) if use_cmd: first = docs[0] - operation.client._process_response(first, operation.session) + operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] _check_command_response(first, conn.max_wire_version) except Exception as exc: duration = datetime.now() - start @@ -278,7 +291,7 @@ class Server: ) # Decrypt response. - client = operation.client + client = operation.client # type: ignore[assignment] if client and client._encrypter: if use_cmd: decrypted = client._encrypter.decrypt(reply.raw_command_response()) @@ -286,7 +299,7 @@ class Server: response: Response - if client._should_pin_cursor(operation.session) or operation.exhaust: + if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the diff --git a/pymongo/synchronous/server_description.py b/pymongo/synchronous/server_description.py deleted file mode 100644 index 4a23fc129..000000000 --- a/pymongo/synchronous/server_description.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Represent one server the driver is connected to.""" -from __future__ import annotations - -import time -import warnings -from typing import Any, Mapping, Optional - -from bson import EPOCH_NAIVE -from bson.objectid import ObjectId -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.typings import ClusterTime, _Address - -_IS_SYNC = True - - -class ServerDescription: - """Immutable representation of one server. 
- - :param address: A (host, port) pair - :param hello: Optional Hello instance - :param round_trip_time: Optional float - :param error: Optional, the last error attempting to connect to the server - :param round_trip_time: Optional float, the min latency from the most recent samples - """ - - __slots__ = ( - "_address", - "_server_type", - "_all_hosts", - "_tags", - "_replica_set_name", - "_primary", - "_max_bson_size", - "_max_message_size", - "_max_write_batch_size", - "_min_wire_version", - "_max_wire_version", - "_round_trip_time", - "_min_round_trip_time", - "_me", - "_is_writable", - "_is_readable", - "_ls_timeout_minutes", - "_error", - "_set_version", - "_election_id", - "_cluster_time", - "_last_write_date", - "_last_update_time", - "_topology_version", - ) - - def __init__( - self, - address: _Address, - hello: Optional[Hello] = None, - round_trip_time: Optional[float] = None, - error: Optional[Exception] = None, - min_round_trip_time: float = 0.0, - ) -> None: - self._address = address - if not hello: - hello = Hello({}) - - self._server_type = hello.server_type - self._all_hosts = hello.all_hosts - self._tags = hello.tags - self._replica_set_name = hello.replica_set_name - self._primary = hello.primary - self._max_bson_size = hello.max_bson_size - self._max_message_size = hello.max_message_size - self._max_write_batch_size = hello.max_write_batch_size - self._min_wire_version = hello.min_wire_version - self._max_wire_version = hello.max_wire_version - self._set_version = hello.set_version - self._election_id = hello.election_id - self._cluster_time = hello.cluster_time - self._is_writable = hello.is_writable - self._is_readable = hello.is_readable - self._ls_timeout_minutes = hello.logical_session_timeout_minutes - self._round_trip_time = round_trip_time - self._min_round_trip_time = min_round_trip_time - self._me = hello.me - self._last_update_time = time.monotonic() - self._error = error - self._topology_version = hello.topology_version - if error: - 
details = getattr(error, "details", None) - if isinstance(details, dict): - self._topology_version = details.get("topologyVersion") - - self._last_write_date: Optional[float] - if hello.last_write_date: - # Convert from datetime to seconds. - delta = hello.last_write_date - EPOCH_NAIVE - self._last_write_date = delta.total_seconds() - else: - self._last_write_date = None - - @property - def address(self) -> _Address: - """The address (host, port) of this server.""" - return self._address - - @property - def server_type(self) -> int: - """The type of this server.""" - return self._server_type - - @property - def server_type_name(self) -> str: - """The server type as a human readable string. - - .. versionadded:: 3.4 - """ - return SERVER_TYPE._fields[self._server_type] - - @property - def all_hosts(self) -> set[tuple[str, int]]: - """List of hosts, passives, and arbiters known to this server.""" - return self._all_hosts - - @property - def tags(self) -> Mapping[str, Any]: - return self._tags - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self._replica_set_name - - @property - def primary(self) -> Optional[tuple[str, int]]: - """This server's opinion about who the primary is, or None.""" - return self._primary - - @property - def max_bson_size(self) -> int: - return self._max_bson_size - - @property - def max_message_size(self) -> int: - return self._max_message_size - - @property - def max_write_batch_size(self) -> int: - return self._max_write_batch_size - - @property - def min_wire_version(self) -> int: - return self._min_wire_version - - @property - def max_wire_version(self) -> int: - return self._max_wire_version - - @property - def set_version(self) -> Optional[int]: - return self._set_version - - @property - def election_id(self) -> Optional[ObjectId]: - return self._election_id - - @property - def cluster_time(self) -> Optional[ClusterTime]: - return self._cluster_time - - @property - def 
election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: - warnings.warn( - "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", - DeprecationWarning, - stacklevel=2, - ) - return self._set_version, self._election_id - - @property - def me(self) -> Optional[tuple[str, int]]: - return self._me - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - return self._ls_timeout_minutes - - @property - def last_write_date(self) -> Optional[float]: - return self._last_write_date - - @property - def last_update_time(self) -> float: - return self._last_update_time - - @property - def round_trip_time(self) -> Optional[float]: - """The current average latency or None.""" - # This override is for unittesting only! - if self._address in self._host_to_round_trip_time: - return self._host_to_round_trip_time[self._address] - - return self._round_trip_time - - @property - def min_round_trip_time(self) -> float: - """The min latency from the most recent samples.""" - return self._min_round_trip_time - - @property - def error(self) -> Optional[Exception]: - """The last error attempting to connect to the server, or None.""" - return self._error - - @property - def is_writable(self) -> bool: - return self._is_writable - - @property - def is_readable(self) -> bool: - return self._is_readable - - @property - def mongos(self) -> bool: - return self._server_type == SERVER_TYPE.Mongos - - @property - def is_server_type_known(self) -> bool: - return self.server_type != SERVER_TYPE.Unknown - - @property - def retryable_writes_supported(self) -> bool: - """Checks if this server supports retryable writes.""" - return ( - self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) - ) or self._server_type == SERVER_TYPE.LoadBalancer - - @property - def retryable_reads_supported(self) -> bool: - """Checks if this server supports retryable writes.""" - return self._max_wire_version >= 6 - - @property - def topology_version(self) -> 
Optional[Mapping[str, Any]]: - return self._topology_version - - def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription: - unknown = ServerDescription(self.address, error=error) - unknown._topology_version = self.topology_version - return unknown - - def __eq__(self, other: Any) -> bool: - if isinstance(other, ServerDescription): - return ( - (self._address == other.address) - and (self._server_type == other.server_type) - and (self._min_wire_version == other.min_wire_version) - and (self._max_wire_version == other.max_wire_version) - and (self._me == other.me) - and (self._all_hosts == other.all_hosts) - and (self._tags == other.tags) - and (self._replica_set_name == other.replica_set_name) - and (self._set_version == other.set_version) - and (self._election_id == other.election_id) - and (self._primary == other.primary) - and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) - and (self._error == other.error) - ) - - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - errmsg = "" - if self.error: - errmsg = f", error={self.error!r}" - return "<{} {} server_type: {}, rtt: {}{}>".format( - self.__class__.__name__, - self.address, - self.server_type_name, - self.round_trip_time, - errmsg, - ) - - # For unittesting only. Use under no circumstances! 
- _host_to_round_trip_time: dict = {} diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index f51b5307a..8719e8608 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -20,12 +20,14 @@ import traceback from typing import Any, Collection, Optional, Type, Union from bson.objectid import ObjectId +from pymongo import common +from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT from pymongo.errors import ConfigurationError -from pymongo.synchronous import common, monitor, pool -from pymongo.synchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT -from pymongo.synchronous.pool import Pool, PoolOptions -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.synchronous import monitor, pool +from pymongo.synchronous.pool import Pool +from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector _IS_SYNC = True diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 6c8cd8870..b2c102ae0 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -27,7 +27,7 @@ import weakref from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, helpers_constants +from pymongo import _csot, common, helpers_shared from pymongo.errors import ( ConnectionFailure, InvalidOperation, @@ -38,27 +38,28 @@ from pymongo.errors import ( ServerSelectionTimeoutError, WriteError, ) +from pymongo.hello import Hello from pymongo.lock import _create_lock -from pymongo.synchronous import common, periodic_executor -from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.synchronous.hello import Hello -from 
pymongo.synchronous.logger import ( +from pymongo.logger import ( _SERVER_SELECTION_LOGGER, _debug_log, _ServerSelectionStatusMessage, ) -from pymongo.synchronous.monitor import SrvMonitor -from pymongo.synchronous.pool import Pool, PoolOptions -from pymongo.synchronous.server import Server -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import ( +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import ( Selection, any_server_selector, arbiter_server_selector, secondary_server_selector, writable_server_selector, ) -from pymongo.synchronous.topology_description import ( +from pymongo.synchronous import periodic_executor +from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool +from pymongo.synchronous.monitor import SrvMonitor +from pymongo.synchronous.pool import Pool +from pymongo.synchronous.server import Server +from pymongo.topology_description import ( SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE, TopologyDescription, @@ -69,7 +70,7 @@ from pymongo.synchronous.topology_description import ( if TYPE_CHECKING: from bson import ObjectId from pymongo.synchronous.settings import TopologySettings - from pymongo.synchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = True @@ -779,8 +780,8 @@ class Topology: # Default error code if one does not exist. default = 10107 if isinstance(error, NotPrimaryError) else None err_code = error.details.get("code", default) # type: ignore[union-attr] - if err_code in helpers_constants._NOT_PRIMARY_CODES: - is_shutting_down = err_code in helpers_constants._SHUTDOWN_CODES + if err_code in helpers_shared._NOT_PRIMARY_CODES: + is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES # Mark server Unknown, clear the pool, and request check. 
if not self._settings.load_balanced: self._process_change(ServerDescription(address, error=error)) diff --git a/pymongo/synchronous/topology_description.py b/pymongo/synchronous/topology_description.py deleted file mode 100644 index 961b9da8d..000000000 --- a/pymongo/synchronous/topology_description.py +++ /dev/null @@ -1,678 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Represent a deployment of MongoDB servers.""" -from __future__ import annotations - -from random import sample -from typing import ( - Any, - Callable, - List, - Mapping, - MutableMapping, - NamedTuple, - Optional, - cast, -) - -from bson.min_key import MinKey -from bson.objectid import ObjectId -from pymongo.errors import ConfigurationError -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import common -from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import Selection -from pymongo.synchronous.typings import _Address - -_IS_SYNC = True - - -# Enumeration for various kinds of MongoDB cluster topologies. -class _TopologyType(NamedTuple): - Single: int - ReplicaSetNoPrimary: int - ReplicaSetWithPrimary: int - Sharded: int - Unknown: int - LoadBalanced: int - - -TOPOLOGY_TYPE = _TopologyType(*range(6)) - -# Topologies compatible with SRV record polling. 
-SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) - - -_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] - - -class TopologyDescription: - def __init__( - self, - topology_type: int, - server_descriptions: dict[_Address, ServerDescription], - replica_set_name: Optional[str], - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], - topology_settings: Any, - ) -> None: - """Representation of a deployment of MongoDB servers. - - :param topology_type: initial type - :param server_descriptions: dict of (address, ServerDescription) for - all seeds - :param replica_set_name: replica set name or None - :param max_set_version: greatest setVersion seen from a primary, or None - :param max_election_id: greatest electionId seen from a primary, or None - :param topology_settings: a TopologySettings - """ - self._topology_type = topology_type - self._replica_set_name = replica_set_name - self._server_descriptions = server_descriptions - self._max_set_version = max_set_version - self._max_election_id = max_election_id - - # The heartbeat_frequency is used in staleness estimates. - self._topology_settings = topology_settings - - # Is PyMongo compatible with all servers' wire protocols? - self._incompatible_err = None - if self._topology_type != TOPOLOGY_TYPE.LoadBalanced: - self._init_incompatible_err() - - # Server Discovery And Monitoring Spec: Whenever a client updates the - # TopologyDescription from an hello response, it MUST set - # TopologyDescription.logicalSessionTimeoutMinutes to the smallest - # logicalSessionTimeoutMinutes value among ServerDescriptions of all - # data-bearing server types. If any have a null - # logicalSessionTimeoutMinutes, then - # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. 
- readable_servers = self.readable_servers - if not readable_servers: - self._ls_timeout_minutes = None - elif any(s.logical_session_timeout_minutes is None for s in readable_servers): - self._ls_timeout_minutes = None - else: - self._ls_timeout_minutes = min( # type: ignore[type-var] - s.logical_session_timeout_minutes for s in readable_servers - ) - - def _init_incompatible_err(self) -> None: - """Internal compatibility check for non-load balanced topologies.""" - for s in self._server_descriptions.values(): - if not s.is_server_type_known: - continue - - # s.min/max_wire_version is the server's wire protocol. - # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. - server_too_new = ( - # Server too new. - s.min_wire_version is not None - and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION - ) - - server_too_old = ( - # Server too old. - s.max_wire_version is not None - and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION - ) - - if server_too_new: - self._incompatible_err = ( - "Server at %s:%d requires wire version %d, but this " # type: ignore - "version of PyMongo only supports up to %d." - % ( - s.address[0], - s.address[1] or 0, - s.min_wire_version, - common.MAX_SUPPORTED_WIRE_VERSION, - ) - ) - - elif server_too_old: - self._incompatible_err = ( - "Server at %s:%d reports wire version %d, but this " # type: ignore - "version of PyMongo requires at least %d (MongoDB %s)." - % ( - s.address[0], - s.address[1] or 0, - s.max_wire_version, - common.MIN_SUPPORTED_WIRE_VERSION, - common.MIN_SUPPORTED_SERVER_VERSION, - ) - ) - - break - - def check_compatible(self) -> None: - """Raise ConfigurationError if any server is incompatible. - - A server is incompatible if its wire protocol version range does not - overlap with PyMongo's. 
- """ - if self._incompatible_err: - raise ConfigurationError(self._incompatible_err) - - def has_server(self, address: _Address) -> bool: - return address in self._server_descriptions - - def reset_server(self, address: _Address) -> TopologyDescription: - """A copy of this description, with one server marked Unknown.""" - unknown_sd = self._server_descriptions[address].to_unknown() - return updated_topology_description(self, unknown_sd) - - def reset(self) -> TopologyDescription: - """A copy of this description, with all servers marked Unknown.""" - if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - else: - topology_type = self._topology_type - - # The default ServerDescription's type is Unknown. - sds = {address: ServerDescription(address) for address in self._server_descriptions} - - return TopologyDescription( - topology_type, - sds, - self._replica_set_name, - self._max_set_version, - self._max_election_id, - self._topology_settings, - ) - - def server_descriptions(self) -> dict[_Address, ServerDescription]: - """dict of (address, - :class:`~pymongo.server_description.ServerDescription`). - """ - return self._server_descriptions.copy() - - @property - def topology_type(self) -> int: - """The type of this topology.""" - return self._topology_type - - @property - def topology_type_name(self) -> str: - """The topology type as a human readable string. - - .. 
versionadded:: 3.4 - """ - return TOPOLOGY_TYPE._fields[self._topology_type] - - @property - def replica_set_name(self) -> Optional[str]: - """The replica set name.""" - return self._replica_set_name - - @property - def max_set_version(self) -> Optional[int]: - """Greatest setVersion seen from a primary, or None.""" - return self._max_set_version - - @property - def max_election_id(self) -> Optional[ObjectId]: - """Greatest electionId seen from a primary, or None.""" - return self._max_election_id - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - """Minimum logical session timeout, or None.""" - return self._ls_timeout_minutes - - @property - def known_servers(self) -> list[ServerDescription]: - """List of Servers of types besides Unknown.""" - return [s for s in self._server_descriptions.values() if s.is_server_type_known] - - @property - def has_known_servers(self) -> bool: - """Whether there are any Servers of types besides Unknown.""" - return any(s for s in self._server_descriptions.values() if s.is_server_type_known) - - @property - def readable_servers(self) -> list[ServerDescription]: - """List of readable Servers.""" - return [s for s in self._server_descriptions.values() if s.is_readable] - - @property - def common_wire_version(self) -> Optional[int]: - """Minimum of all servers' max wire versions, or None.""" - servers = self.known_servers - if servers: - return min(s.max_wire_version for s in self.known_servers) - - return None - - @property - def heartbeat_frequency(self) -> int: - return self._topology_settings.heartbeat_frequency - - @property - def srv_max_hosts(self) -> int: - return self._topology_settings._srv_max_hosts - - def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: - if not selection: - return [] - round_trip_times: list[float] = [] - for server in selection.server_descriptions: - if server.round_trip_time is None: - config_err_msg = f"round_trip_time for server 
{server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}" - raise ConfigurationError(config_err_msg) - round_trip_times.append(server.round_trip_time) - # Round trip time in seconds. - fastest = min(round_trip_times) - threshold = self._topology_settings.local_threshold_ms / 1000.0 - return [ - s - for s in selection.server_descriptions - if (cast(float, s.round_trip_time) - fastest) <= threshold - ] - - def apply_selector( - self, - selector: Any, - address: Optional[_Address] = None, - custom_selector: Optional[_ServerSelector] = None, - ) -> list[ServerDescription]: - """List of servers matching the provided selector(s). - - :param selector: a callable that takes a Selection as input and returns - a Selection as output. For example, an instance of a read - preference from :mod:`~pymongo.read_preferences`. - :param address: A server address to select. - :param custom_selector: A callable that augments server - selection rules. Accepts a list of - :class:`~pymongo.server_description.ServerDescription` objects and - return a list of server descriptions that should be considered - suitable for the desired operation. - - .. versionadded:: 3.4 - """ - if getattr(selector, "min_wire_version", 0): - common_wv = self.common_wire_version - if common_wv and common_wv < selector.min_wire_version: - raise ConfigurationError( - "%s requires min wire version %d, but topology's min" - " wire version is %d" % (selector, selector.min_wire_version, common_wv) - ) - - if isinstance(selector, _AggWritePref): - selector.selection_hook(self) - - if self.topology_type == TOPOLOGY_TYPE.Unknown: - return [] - elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced): - # Ignore selectors for standalone and load balancer mode. - return self.known_servers - if address: - # Ignore selectors when explicit address is requested. 
- description = self.server_descriptions().get(address) - return [description] if description else [] - - selection = Selection.from_topology_description(self) - # Ignore read preference for sharded clusters. - if self.topology_type != TOPOLOGY_TYPE.Sharded: - selection = selector(selection) - - # Apply custom selector followed by localThresholdMS. - if custom_selector is not None and selection: - selection = selection.with_server_descriptions( - custom_selector(selection.server_descriptions) - ) - return self._apply_local_threshold(selection) - - def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool: - """Does this topology have any readable servers available matching the - given read preference? - - :param read_preference: an instance of a read preference from - :mod:`~pymongo.read_preferences`. Defaults to - :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. - - .. note:: When connected directly to a single server this method - always returns ``True``. - - .. versionadded:: 3.4 - """ - common.validate_read_preference("read_preference", read_preference) - return any(self.apply_selector(read_preference)) - - def has_writable_server(self) -> bool: - """Does this topology have a writable server available? - - .. note:: When connected directly to a single server this method - always returns ``True``. - - .. versionadded:: 3.4 - """ - return self.has_readable_server(ReadPreference.PRIMARY) - - def __repr__(self) -> str: - # Sort the servers by address. - servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) - return "<{} id: {}, topology_type: {}, servers: {!r}>".format( - self.__class__.__name__, - self._topology_settings._topology_id, - self.topology_type_name, - servers, - ) - - -# If topology type is Unknown and we receive a hello response, what should -# the new topology type be? 
-_SERVER_TYPE_TO_TOPOLOGY_TYPE = { - SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, - SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, - SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. -} - - -def updated_topology_description( - topology_description: TopologyDescription, server_description: ServerDescription -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. - - :param topology_description: the current TopologyDescription - :param server_description: a new ServerDescription that resulted from - a hello call - - Called after attempting (successfully or not) to call hello on the - server at server_description.address. Does not modify topology_description. - """ - address = server_description.address - - # These values will be updated, if necessary, to form the new - # TopologyDescription. - topology_type = topology_description.topology_type - set_name = topology_description.replica_set_name - max_set_version = topology_description.max_set_version - max_election_id = topology_description.max_election_id - server_type = server_description.server_type - - # Don't mutate the original dict of server descriptions; copy it. - sds = topology_description.server_descriptions() - - # Replace this server's description with the new one. - sds[address] = server_description - - if topology_type == TOPOLOGY_TYPE.Single: - # Set server type to Unknown if replica set name does not match. 
- if set_name is not None and set_name != server_description.replica_set_name: - error = ConfigurationError( - "client is configured to connect to a replica set named " - "'{}' but this node belongs to a set named '{}'".format( - set_name, server_description.replica_set_name - ) - ) - sds[address] = server_description.to_unknown(error=error) - # Single type never changes. - return TopologyDescription( - TOPOLOGY_TYPE.Single, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - if topology_type == TOPOLOGY_TYPE.Unknown: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): - if len(topology_description._topology_settings.seeds) == 1: - topology_type = TOPOLOGY_TYPE.Single - else: - # Remove standalone from Topology when given multiple seeds. - sds.pop(address) - elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): - topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] - - if topology_type == TOPOLOGY_TYPE.Sharded: - if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): - sds.pop(address) - - elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): - sds.pop(address) - - elif server_type == SERVER_TYPE.RSPrimary: - (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( - sds, set_name, server_description, max_set_version, max_election_id - ) - - elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): - topology_type, set_name = _update_rs_no_primary_from_member( - sds, set_name, server_description - ) - - elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): - sds.pop(address) - topology_type = _check_has_primary(sds) - - elif server_type == SERVER_TYPE.RSPrimary: - (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( - sds, 
set_name, server_description, max_set_version, max_election_id - ) - - elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): - topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description) - - else: - # Server type is Unknown or RSGhost: did we just lose the primary? - topology_type = _check_has_primary(sds) - - # Return updated copy. - return TopologyDescription( - topology_type, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - -def _updated_topology_description_srv_polling( - topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. - - :param topology_description: the current TopologyDescription - :param seedlist: a list of new seeds new ServerDescription that resulted from - a hello call - """ - assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES - # Create a copy of the server descriptions. - sds = topology_description.server_descriptions() - - # If seeds haven't changed, don't do anything. - if set(sds.keys()) == set(seedlist): - return topology_description - - # Remove SDs corresponding to servers no longer part of the SRV record. - for address in list(sds.keys()): - if address not in seedlist: - sds.pop(address) - - if topology_description.srv_max_hosts != 0: - new_hosts = set(seedlist) - set(sds.keys()) - n_to_add = topology_description.srv_max_hosts - len(sds) - if n_to_add > 0: - seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts))) - else: - seedlist = [] - # Add SDs corresponding to servers recently added to the SRV record. 
- for address in seedlist: - if address not in sds: - sds[address] = ServerDescription(address) - return TopologyDescription( - topology_description.topology_type, - sds, - topology_description.replica_set_name, - topology_description.max_set_version, - topology_description.max_election_id, - topology_description._topology_settings, - ) - - -def _update_rs_from_primary( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], -) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: - """Update topology description from a primary's hello response. - - Pass in a dict of ServerDescriptions, current replica set name, the - ServerDescription we are processing, and the TopologyDescription's - max_set_version and max_election_id if any. - - Returns (new topology type, new replica_set_name, new max_set_version, - new max_election_id). - """ - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - # We found a primary but it doesn't have the replica_set_name - # provided by the user. - sds.pop(server_description.address) - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - - if server_description.max_wire_version is None or server_description.max_wire_version < 17: - new_election_tuple: tuple = (server_description.set_version, server_description.election_id) - max_election_tuple: tuple = (max_set_version, max_election_id) - if None not in new_election_tuple: - if None not in max_election_tuple and new_election_tuple < max_election_tuple: - # Stale primary, set to type Unknown. 
- sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - max_election_id = server_description.election_id - - if server_description.set_version is not None and ( - max_set_version is None or server_description.set_version > max_set_version - ): - max_set_version = server_description.set_version - else: - new_election_tuple = server_description.election_id, server_description.set_version - max_election_tuple = max_election_id, max_set_version - new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple) - max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple) - if new_election_safe < max_election_safe: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - else: - max_election_id = server_description.election_id - max_set_version = server_description.set_version - - # We've heard from the primary. Is it the same primary as before? - for server in sds.values(): - if ( - server.server_type is SERVER_TYPE.RSPrimary - and server.address != server_description.address - ): - # Reset old primary's type to Unknown. - sds[server.address] = server.to_unknown() - - # There can be only one prior primary. - break - - # Discover new hosts from this primary's response. - for new_address in server_description.all_hosts: - if new_address not in sds: - sds[new_address] = ServerDescription(new_address) - - # Remove hosts not in the response. - for addr in set(sds) - server_description.all_hosts: - sds.pop(addr) - - # If the host list differs from the seed list, we may not have a primary - # after all. 
- return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) - - -def _update_rs_with_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> int: - """RS with known primary. Process a response from a non-primary. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns new topology type. - """ - assert replica_set_name is not None - - if replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - elif server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - # Had this member been the primary? - return _check_has_primary(sds) - - -def _update_rs_no_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> tuple[int, Optional[str]]: - """RS without known primary. Update from a non-primary's response. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns (new topology type, new replica_set_name). - """ - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - return topology_type, replica_set_name - - # This isn't the primary's response, so don't remove any servers - # it doesn't report. Only add new servers. 
- for address in server_description.all_hosts: - if address not in sds: - sds[address] = ServerDescription(address) - - if server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - return topology_type, replica_set_name - - -def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: - """Current topology type is ReplicaSetWithPrimary. Is primary still known? - - Pass in a dict of ServerDescriptions. - - Returns new topology type. - """ - for s in sds.values(): - if s.server_type == SERVER_TYPE.RSPrimary: - return TOPOLOGY_TYPE.ReplicaSetWithPrimary - else: # noqa: PLW0120 - return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py deleted file mode 100644 index 8e37bdc69..000000000 --- a/pymongo/synchronous/uri_parser.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2011-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
- - -"""Tools to parse and validate a MongoDB URI.""" -from __future__ import annotations - -import re -import sys -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sized, - Union, - cast, -) -from urllib.parse import unquote_plus - -from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.synchronous.client_options import _parse_ssl_options -from pymongo.synchronous.common import ( - INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, - URI_OPTIONS_DEPRECATION_MAP, - _CaseInsensitiveDictionary, - get_validated_options, -) -from pymongo.synchronous.srv_resolver import _have_dnspython, _SrvResolver -from pymongo.synchronous.typings import _Address - -if TYPE_CHECKING: - from pymongo.pyopenssl_context import SSLContext - -_IS_SYNC = True -SCHEME = "mongodb://" -SCHEME_LEN = len(SCHEME) -SRV_SCHEME = "mongodb+srv://" -SRV_SCHEME_LEN = len(SRV_SCHEME) -DEFAULT_PORT = 27017 - - -def _unquoted_percent(s: str) -> bool: - """Check for unescaped percent signs. - - :param s: A string. `s` can have things like '%25', '%2525', - and '%E2%85%A8' but cannot have unquoted percent like '%foo'. - """ - for i in range(len(s)): - if s[i] == "%": - sub = s[i : i + 3] - # If unquoting yields the same string this means there was an - # unquoted %. - if unquote_plus(sub) == sub: - return True - return False - - -def parse_userinfo(userinfo: str) -> tuple[str, str]: - """Validates the format of user information in a MongoDB URI. - Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", - "]", "@") as per RFC 3986 must be escaped. - - Returns a 2-tuple containing the unescaped username followed - by the unescaped password. 
- - :param userinfo: A string of the form : - """ - if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): - raise InvalidURI( - "Username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - - user, _, passwd = userinfo.partition(":") - # No password is expected with GSSAPI authentication. - if not user: - raise InvalidURI("The empty string is not valid username.") - - return unquote_plus(user), unquote_plus(passwd) - - -def parse_ipv6_literal_host( - entity: str, default_port: Optional[int] -) -> tuple[str, Optional[Union[str, int]]]: - """Validates an IPv6 literal host:port string. - - Returns a 2-tuple of IPv6 literal followed by port where - port is default_port if it wasn't specified in entity. - - :param entity: A string that represents an IPv6 literal enclosed - in braces (e.g. '[::1]' or '[::1]:27017'). - :param default_port: The port number to use when one wasn't - specified in entity. - """ - if entity.find("]") == -1: - raise ValueError( - "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." - ) - i = entity.find("]:") - if i == -1: - return entity[1:-1], default_port - return entity[1:i], entity[i + 2 :] - - -def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: - """Validates a host string - - Returns a 2-tuple of host followed by port where port is default_port - if it wasn't specified in the string. - - :param entity: A host or host:port string where host could be a - hostname or IP address. - :param default_port: The port number to use when one wasn't - specified in entity. 
- """ - host = entity - port: Optional[Union[str, int]] = default_port - if entity[0] == "[": - host, port = parse_ipv6_literal_host(entity, default_port) - elif entity.endswith(".sock"): - return entity, default_port - elif entity.find(":") != -1: - if entity.count(":") > 1: - raise ValueError( - "Reserved characters such as ':' must be " - "escaped according RFC 2396. An IPv6 " - "address literal must be enclosed in '[' " - "and ']' according to RFC 2732." - ) - host, port = host.split(":", 1) - if isinstance(port, str): - if not port.isdigit() or int(port) > 65535 or int(port) <= 0: - raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") - port = int(port) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -# Options whose values are implicitly determined by tlsInsecure. -_IMPLICIT_TLSINSECURE_OPTS = { - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", -} - - -def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: - """Helper method for split_options which creates the options dict. - Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string. 
- """ - options = _CaseInsensitiveDictionary() - for uriopt in opts.split(delim): - key, value = uriopt.split("=") - if key.lower() == "readpreferencetags": - options.setdefault(key, []).append(value) - else: - if key in options: - warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) - if key.lower() == "authmechanismproperties": - val = value - else: - val = unquote_plus(value) - options[key] = val - - return options - - -def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Raise appropriate errors when conflicting TLS options are present in - the options dictionary. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Implicitly defined options must not be explicitly specified. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - if opt in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) - ) - - # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. - tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") - if tlsallowinvalidcerts is not None: - if "tlsdisableocspendpointcheck" in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg - % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) - ) - if tlsallowinvalidcerts is True: - options["tlsdisableocspendpointcheck"] = True - - # Handle co-occurence of CRL and OCSP-related options. - tlscrlfile = options.get("tlscrlfile") - if tlscrlfile is not None: - for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): - if options.get(opt) is True: - err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." 
- raise InvalidURI(err_msg % (opt,)) - - if "ssl" in options and "tls" in options: - - def truth_value(val: Any) -> Any: - if val in ("true", "false"): - return val == "true" - if isinstance(val, bool): - return val - return val - - if truth_value(options.get("ssl")) != truth_value(options.get("tls")): - err_msg = "Can not specify conflicting values for URI options %s and %s." - raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) - - return options - - -def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Issue appropriate warnings when deprecated options are present in the - options dictionary. Removes deprecated option key, value pairs if the - options dictionary is found to also have the renamed option. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - for optname in list(options): - if optname in URI_OPTIONS_DEPRECATION_MAP: - mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] - if mode == "renamed": - newoptname = message - if newoptname in options: - warn_msg = "Deprecated option '%s' ignored in favor of '%s'." - warnings.warn( - warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), - DeprecationWarning, - stacklevel=2, - ) - options.pop(optname) - continue - warn_msg = "Option '%s' is deprecated, use '%s' instead." - warnings.warn( - warn_msg % (options.cased_key(optname), newoptname), - DeprecationWarning, - stacklevel=2, - ) - elif mode == "removed": - warn_msg = "Option '%s' is deprecated. %s." - warnings.warn( - warn_msg % (options.cased_key(optname), message), - DeprecationWarning, - stacklevel=2, - ) - - return options - - -def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Normalizes option names in the options dictionary by converting them to - their internally-used names. 
- - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Expand the tlsInsecure option. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - # Implicit options are logically the same as tlsInsecure. - options[opt] = tlsinsecure - - for optname in list(options): - intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) - if intname is not None: - options[intname] = options.pop(optname) - - return options - - -def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: - """Validates and normalizes options passed in a MongoDB URI. - - Returns a new dictionary of validated and normalized options. If warn is - False then errors will be thrown for invalid options, otherwise they will - be ignored and a warning will be issued. - - :param opts: A dict of MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. Otherwise invalid options will - cause errors. - """ - return get_validated_options(opts, warn) - - -def split_options( - opts: str, validate: bool = True, warn: bool = False, normalize: bool = True -) -> MutableMapping[str, Any]: - """Takes the options portion of a MongoDB URI, validates each option - and returns the options in a dictionary. - - :param opt: A string representing MongoDB URI options. - :param validate: If ``True`` (the default), validate and normalize all - options. - :param warn: If ``False`` (the default), suppress all warnings raised - during validation of options. - :param normalize: If ``True`` (the default), renames all options to their - internally-used names. 
- """ - and_idx = opts.find("&") - semi_idx = opts.find(";") - try: - if and_idx >= 0 and semi_idx >= 0: - raise InvalidURI("Can not mix '&' and ';' for option separators.") - elif and_idx >= 0: - options = _parse_options(opts, "&") - elif semi_idx >= 0: - options = _parse_options(opts, ";") - elif opts.find("=") != -1: - options = _parse_options(opts, None) - else: - raise ValueError - except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs.") from None - - options = _handle_security_options(options) - - options = _handle_option_deprecations(options) - - if normalize: - options = _normalize_options(options) - - if validate: - options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) - if options.get("authsource") == "": - raise InvalidURI("the authSource database cannot be an empty string") - - return options - - -def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: - """Takes a string of the form host1[:port],host2[:port]... and - splits it into (host, port) tuples. If [:port] isn't present the - default_port is used. - - Returns a set of 2-tuples containing the host name (or IP) followed by - port number. - - :param hosts: A string of the form host1[:port],host2[:port],... - :param default_port: The port number to use when one wasn't specified - for a host. - """ - nodes = [] - for entity in hosts.split(","): - if not entity: - raise ConfigurationError("Empty host (or extra comma in host list).") - port = default_port - # Unix socket entities don't have ports - if entity.endswith(".sock"): - port = None - nodes.append(parse_host(entity, port)) - return nodes - - -# Prohibited characters in database name. DB names also can't have ".", but for -# backward-compat we allow "db.collection" in URI. 
-_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") - -_ALLOWED_TXT_OPTS = frozenset( - ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] -) - - -def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: - # Ensure directConnection was not True if there are multiple seeds. - if len(nodes) > 1 and options.get("directconnection"): - raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") - - if options.get("loadbalanced"): - if len(nodes) > 1: - raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") - if options.get("directconnection"): - raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") - if options.get("replicaset"): - raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") - - -def parse_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': , - 'username': or None, - 'password': or None, - 'database': or None, - 'collection': or None, - 'options': , - 'fqdn': or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. 
- :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. - """ - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." 
in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - - # Use the connection timeout. connectTimeoutMS passed as a keyword - # argument overrides the same option passed in the connection string. 
- connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: - """Parse KMS TLS connection options.""" - if not kms_tls_options: - return {} - if not isinstance(kms_tls_options, dict): - raise TypeError("kms_tls_options must be a dict") - contexts = {} - for provider, options in kms_tls_options.items(): - if not isinstance(options, dict): - raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - options.setdefault("tls", True) - opts = 
_CaseInsensitiveDictionary(options) - opts = _handle_security_options(opts) - opts = _normalize_options(opts) - opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) - ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) - if ssl_context is None: - raise ConfigurationError("TLS is required for KMS providers") - if allow_invalid_hostnames: - raise ConfigurationError("Insecure TLS options prohibited") - - for n in [ - "tlsInsecure", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsDisableCertificateRevocationCheck", - ]: - if n in opts: - raise ConfigurationError(f"Insecure TLS options prohibited: {n}") - contexts[provider] = ssl_context - return contexts - - -if __name__ == "__main__": - import pprint - - try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index 201d9b390..cc2330cba 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -1,21 +1,676 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2014-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. 
See the License for the specific language governing +# permissions and limitations under the License. -"""Re-import of synchronous TopologyDescription API for compatibility.""" +"""Represent a deployment of MongoDB servers.""" from __future__ import annotations -from pymongo.synchronous.topology_description import * # noqa: F403 -from pymongo.synchronous.topology_description import __doc__ as original_doc +from random import sample +from typing import ( + Any, + Callable, + List, + Mapping, + MutableMapping, + NamedTuple, + Optional, + cast, +) -__doc__ = original_doc +from bson.min_key import MinKey +from bson.objectid import ObjectId +from pymongo import common +from pymongo.errors import ConfigurationError +from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import _Address + + +# Enumeration for various kinds of MongoDB cluster topologies. +class _TopologyType(NamedTuple): + Single: int + ReplicaSetNoPrimary: int + ReplicaSetWithPrimary: int + Sharded: int + Unknown: int + LoadBalanced: int + + +TOPOLOGY_TYPE = _TopologyType(*range(6)) + +# Topologies compatible with SRV record polling. +SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) + + +_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] + + +class TopologyDescription: + def __init__( + self, + topology_type: int, + server_descriptions: dict[_Address, ServerDescription], + replica_set_name: Optional[str], + max_set_version: Optional[int], + max_election_id: Optional[ObjectId], + topology_settings: Any, + ) -> None: + """Representation of a deployment of MongoDB servers. 
+ + :param topology_type: initial type + :param server_descriptions: dict of (address, ServerDescription) for + all seeds + :param replica_set_name: replica set name or None + :param max_set_version: greatest setVersion seen from a primary, or None + :param max_election_id: greatest electionId seen from a primary, or None + :param topology_settings: a TopologySettings + """ + self._topology_type = topology_type + self._replica_set_name = replica_set_name + self._server_descriptions = server_descriptions + self._max_set_version = max_set_version + self._max_election_id = max_election_id + + # The heartbeat_frequency is used in staleness estimates. + self._topology_settings = topology_settings + + # Is PyMongo compatible with all servers' wire protocols? + self._incompatible_err = None + if self._topology_type != TOPOLOGY_TYPE.LoadBalanced: + self._init_incompatible_err() + + # Server Discovery And Monitoring Spec: Whenever a client updates the + # TopologyDescription from an hello response, it MUST set + # TopologyDescription.logicalSessionTimeoutMinutes to the smallest + # logicalSessionTimeoutMinutes value among ServerDescriptions of all + # data-bearing server types. If any have a null + # logicalSessionTimeoutMinutes, then + # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. + readable_servers = self.readable_servers + if not readable_servers: + self._ls_timeout_minutes = None + elif any(s.logical_session_timeout_minutes is None for s in readable_servers): + self._ls_timeout_minutes = None + else: + self._ls_timeout_minutes = min( # type: ignore[type-var] + s.logical_session_timeout_minutes for s in readable_servers + ) + + def _init_incompatible_err(self) -> None: + """Internal compatibility check for non-load balanced topologies.""" + for s in self._server_descriptions.values(): + if not s.is_server_type_known: + continue + + # s.min/max_wire_version is the server's wire protocol. 
+ # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. + server_too_new = ( + # Server too new. + s.min_wire_version is not None + and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION + ) + + server_too_old = ( + # Server too old. + s.max_wire_version is not None + and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION + ) + + if server_too_new: + self._incompatible_err = ( + "Server at %s:%d requires wire version %d, but this " # type: ignore + "version of PyMongo only supports up to %d." + % ( + s.address[0], + s.address[1] or 0, + s.min_wire_version, + common.MAX_SUPPORTED_WIRE_VERSION, + ) + ) + + elif server_too_old: + self._incompatible_err = ( + "Server at %s:%d reports wire version %d, but this " # type: ignore + "version of PyMongo requires at least %d (MongoDB %s)." + % ( + s.address[0], + s.address[1] or 0, + s.max_wire_version, + common.MIN_SUPPORTED_WIRE_VERSION, + common.MIN_SUPPORTED_SERVER_VERSION, + ) + ) + + break + + def check_compatible(self) -> None: + """Raise ConfigurationError if any server is incompatible. + + A server is incompatible if its wire protocol version range does not + overlap with PyMongo's. + """ + if self._incompatible_err: + raise ConfigurationError(self._incompatible_err) + + def has_server(self, address: _Address) -> bool: + return address in self._server_descriptions + + def reset_server(self, address: _Address) -> TopologyDescription: + """A copy of this description, with one server marked Unknown.""" + unknown_sd = self._server_descriptions[address].to_unknown() + return updated_topology_description(self, unknown_sd) + + def reset(self) -> TopologyDescription: + """A copy of this description, with all servers marked Unknown.""" + if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: + topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary + else: + topology_type = self._topology_type + + # The default ServerDescription's type is Unknown. 
+ sds = {address: ServerDescription(address) for address in self._server_descriptions} + + return TopologyDescription( + topology_type, + sds, + self._replica_set_name, + self._max_set_version, + self._max_election_id, + self._topology_settings, + ) + + def server_descriptions(self) -> dict[_Address, ServerDescription]: + """dict of (address, + :class:`~pymongo.server_description.ServerDescription`). + """ + return self._server_descriptions.copy() + + @property + def topology_type(self) -> int: + """The type of this topology.""" + return self._topology_type + + @property + def topology_type_name(self) -> str: + """The topology type as a human readable string. + + .. versionadded:: 3.4 + """ + return TOPOLOGY_TYPE._fields[self._topology_type] + + @property + def replica_set_name(self) -> Optional[str]: + """The replica set name.""" + return self._replica_set_name + + @property + def max_set_version(self) -> Optional[int]: + """Greatest setVersion seen from a primary, or None.""" + return self._max_set_version + + @property + def max_election_id(self) -> Optional[ObjectId]: + """Greatest electionId seen from a primary, or None.""" + return self._max_election_id + + @property + def logical_session_timeout_minutes(self) -> Optional[int]: + """Minimum logical session timeout, or None.""" + return self._ls_timeout_minutes + + @property + def known_servers(self) -> list[ServerDescription]: + """List of Servers of types besides Unknown.""" + return [s for s in self._server_descriptions.values() if s.is_server_type_known] + + @property + def has_known_servers(self) -> bool: + """Whether there are any Servers of types besides Unknown.""" + return any(s for s in self._server_descriptions.values() if s.is_server_type_known) + + @property + def readable_servers(self) -> list[ServerDescription]: + """List of readable Servers.""" + return [s for s in self._server_descriptions.values() if s.is_readable] + + @property + def common_wire_version(self) -> Optional[int]: + """Minimum 
of all servers' max wire versions, or None.""" + servers = self.known_servers + if servers: + return min(s.max_wire_version for s in self.known_servers) + + return None + + @property + def heartbeat_frequency(self) -> int: + return self._topology_settings.heartbeat_frequency + + @property + def srv_max_hosts(self) -> int: + return self._topology_settings._srv_max_hosts + + def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: + if not selection: + return [] + round_trip_times: list[float] = [] + for server in selection.server_descriptions: + if server.round_trip_time is None: + config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}" + raise ConfigurationError(config_err_msg) + round_trip_times.append(server.round_trip_time) + # Round trip time in seconds. + fastest = min(round_trip_times) + threshold = self._topology_settings.local_threshold_ms / 1000.0 + return [ + s + for s in selection.server_descriptions + if (cast(float, s.round_trip_time) - fastest) <= threshold + ] + + def apply_selector( + self, + selector: Any, + address: Optional[_Address] = None, + custom_selector: Optional[_ServerSelector] = None, + ) -> list[ServerDescription]: + """List of servers matching the provided selector(s). + + :param selector: a callable that takes a Selection as input and returns + a Selection as output. For example, an instance of a read + preference from :mod:`~pymongo.read_preferences`. + :param address: A server address to select. + :param custom_selector: A callable that augments server + selection rules. Accepts a list of + :class:`~pymongo.server_description.ServerDescription` objects and + return a list of server descriptions that should be considered + suitable for the desired operation. + + .. 
versionadded:: 3.4 + """ + if getattr(selector, "min_wire_version", 0): + common_wv = self.common_wire_version + if common_wv and common_wv < selector.min_wire_version: + raise ConfigurationError( + "%s requires min wire version %d, but topology's min" + " wire version is %d" % (selector, selector.min_wire_version, common_wv) + ) + + if isinstance(selector, _AggWritePref): + selector.selection_hook(self) + + if self.topology_type == TOPOLOGY_TYPE.Unknown: + return [] + elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced): + # Ignore selectors for standalone and load balancer mode. + return self.known_servers + if address: + # Ignore selectors when explicit address is requested. + description = self.server_descriptions().get(address) + return [description] if description else [] + + selection = Selection.from_topology_description(self) + # Ignore read preference for sharded clusters. + if self.topology_type != TOPOLOGY_TYPE.Sharded: + selection = selector(selection) + + # Apply custom selector followed by localThresholdMS. + if custom_selector is not None and selection: + selection = selection.with_server_descriptions( + custom_selector(selection.server_descriptions) + ) + return self._apply_local_threshold(selection) + + def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool: + """Does this topology have any readable servers available matching the + given read preference? + + :param read_preference: an instance of a read preference from + :mod:`~pymongo.read_preferences`. Defaults to + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + + .. note:: When connected directly to a single server this method + always returns ``True``. + + .. versionadded:: 3.4 + """ + common.validate_read_preference("read_preference", read_preference) + return any(self.apply_selector(read_preference)) + + def has_writable_server(self) -> bool: + """Does this topology have a writable server available? + + .. 
note:: When connected directly to a single server this method + always returns ``True``. + + .. versionadded:: 3.4 + """ + return self.has_readable_server(ReadPreference.PRIMARY) + + def __repr__(self) -> str: + # Sort the servers by address. + servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) + return "<{} id: {}, topology_type: {}, servers: {!r}>".format( + self.__class__.__name__, + self._topology_settings._topology_id, + self.topology_type_name, + servers, + ) + + +# If topology type is Unknown and we receive a hello response, what should +# the new topology type be? +_SERVER_TYPE_TO_TOPOLOGY_TYPE = { + SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, + SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, + SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. +} + + +def updated_topology_description( + topology_description: TopologyDescription, server_description: ServerDescription +) -> TopologyDescription: + """Return an updated copy of a TopologyDescription. + + :param topology_description: the current TopologyDescription + :param server_description: a new ServerDescription that resulted from + a hello call + + Called after attempting (successfully or not) to call hello on the + server at server_description.address. Does not modify topology_description. + """ + address = server_description.address + + # These values will be updated, if necessary, to form the new + # TopologyDescription. + topology_type = topology_description.topology_type + set_name = topology_description.replica_set_name + max_set_version = topology_description.max_set_version + max_election_id = topology_description.max_election_id + server_type = server_description.server_type + + # Don't mutate the original dict of server descriptions; copy it. 
+ sds = topology_description.server_descriptions() + + # Replace this server's description with the new one. + sds[address] = server_description + + if topology_type == TOPOLOGY_TYPE.Single: + # Set server type to Unknown if replica set name does not match. + if set_name is not None and set_name != server_description.replica_set_name: + error = ConfigurationError( + "client is configured to connect to a replica set named " + "'{}' but this node belongs to a set named '{}'".format( + set_name, server_description.replica_set_name + ) + ) + sds[address] = server_description.to_unknown(error=error) + # Single type never changes. + return TopologyDescription( + TOPOLOGY_TYPE.Single, + sds, + set_name, + max_set_version, + max_election_id, + topology_description._topology_settings, + ) + + if topology_type == TOPOLOGY_TYPE.Unknown: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): + if len(topology_description._topology_settings.seeds) == 1: + topology_type = TOPOLOGY_TYPE.Single + else: + # Remove standalone from Topology when given multiple seeds. 
+ sds.pop(address) + elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): + topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] + + if topology_type == TOPOLOGY_TYPE.Sharded: + if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): + sds.pop(address) + + elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): + sds.pop(address) + + elif server_type == SERVER_TYPE.RSPrimary: + (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( + sds, set_name, server_description, max_set_version, max_election_id + ) + + elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): + topology_type, set_name = _update_rs_no_primary_from_member( + sds, set_name, server_description + ) + + elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): + sds.pop(address) + topology_type = _check_has_primary(sds) + + elif server_type == SERVER_TYPE.RSPrimary: + (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( + sds, set_name, server_description, max_set_version, max_election_id + ) + + elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): + topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description) + + else: + # Server type is Unknown or RSGhost: did we just lose the primary? + topology_type = _check_has_primary(sds) + + # Return updated copy. + return TopologyDescription( + topology_type, + sds, + set_name, + max_set_version, + max_election_id, + topology_description._topology_settings, + ) + + +def _updated_topology_description_srv_polling( + topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] +) -> TopologyDescription: + """Return an updated copy of a TopologyDescription. 
+ + :param topology_description: the current TopologyDescription + :param seedlist: a list of new seeds new ServerDescription that resulted from + a hello call + """ + assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES + # Create a copy of the server descriptions. + sds = topology_description.server_descriptions() + + # If seeds haven't changed, don't do anything. + if set(sds.keys()) == set(seedlist): + return topology_description + + # Remove SDs corresponding to servers no longer part of the SRV record. + for address in list(sds.keys()): + if address not in seedlist: + sds.pop(address) + + if topology_description.srv_max_hosts != 0: + new_hosts = set(seedlist) - set(sds.keys()) + n_to_add = topology_description.srv_max_hosts - len(sds) + if n_to_add > 0: + seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts))) + else: + seedlist = [] + # Add SDs corresponding to servers recently added to the SRV record. + for address in seedlist: + if address not in sds: + sds[address] = ServerDescription(address) + return TopologyDescription( + topology_description.topology_type, + sds, + topology_description.replica_set_name, + topology_description.max_set_version, + topology_description.max_election_id, + topology_description._topology_settings, + ) + + +def _update_rs_from_primary( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, + max_set_version: Optional[int], + max_election_id: Optional[ObjectId], +) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: + """Update topology description from a primary's hello response. + + Pass in a dict of ServerDescriptions, current replica set name, the + ServerDescription we are processing, and the TopologyDescription's + max_set_version and max_election_id if any. + + Returns (new topology type, new replica_set_name, new max_set_version, + new max_election_id). 
+ """ + if replica_set_name is None: + replica_set_name = server_description.replica_set_name + + elif replica_set_name != server_description.replica_set_name: + # We found a primary but it doesn't have the replica_set_name + # provided by the user. + sds.pop(server_description.address) + return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id + + if server_description.max_wire_version is None or server_description.max_wire_version < 17: + new_election_tuple: tuple = (server_description.set_version, server_description.election_id) + max_election_tuple: tuple = (max_set_version, max_election_id) + if None not in new_election_tuple: + if None not in max_election_tuple and new_election_tuple < max_election_tuple: + # Stale primary, set to type Unknown. + sds[server_description.address] = server_description.to_unknown() + return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id + max_election_id = server_description.election_id + + if server_description.set_version is not None and ( + max_set_version is None or server_description.set_version > max_set_version + ): + max_set_version = server_description.set_version + else: + new_election_tuple = server_description.election_id, server_description.set_version + max_election_tuple = max_election_id, max_set_version + new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple) + max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple) + if new_election_safe < max_election_safe: + # Stale primary, set to type Unknown. + sds[server_description.address] = server_description.to_unknown() + return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id + else: + max_election_id = server_description.election_id + max_set_version = server_description.set_version + + # We've heard from the primary. Is it the same primary as before? 
+ for server in sds.values(): + if ( + server.server_type is SERVER_TYPE.RSPrimary + and server.address != server_description.address + ): + # Reset old primary's type to Unknown. + sds[server.address] = server.to_unknown() + + # There can be only one prior primary. + break + + # Discover new hosts from this primary's response. + for new_address in server_description.all_hosts: + if new_address not in sds: + sds[new_address] = ServerDescription(new_address) + + # Remove hosts not in the response. + for addr in set(sds) - server_description.all_hosts: + sds.pop(addr) + + # If the host list differs from the seed list, we may not have a primary + # after all. + return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) + + +def _update_rs_with_primary_from_member( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, +) -> int: + """RS with known primary. Process a response from a non-primary. + + Pass in a dict of ServerDescriptions, current replica set name, and the + ServerDescription we are processing. + + Returns new topology type. + """ + assert replica_set_name is not None + + if replica_set_name != server_description.replica_set_name: + sds.pop(server_description.address) + elif server_description.me and server_description.address != server_description.me: + sds.pop(server_description.address) + + # Had this member been the primary? + return _check_has_primary(sds) + + +def _update_rs_no_primary_from_member( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, +) -> tuple[int, Optional[str]]: + """RS without known primary. Update from a non-primary's response. + + Pass in a dict of ServerDescriptions, current replica set name, and the + ServerDescription we are processing. + + Returns (new topology type, new replica_set_name). 
+ """ + topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary + if replica_set_name is None: + replica_set_name = server_description.replica_set_name + + elif replica_set_name != server_description.replica_set_name: + sds.pop(server_description.address) + return topology_type, replica_set_name + + # This isn't the primary's response, so don't remove any servers + # it doesn't report. Only add new servers. + for address in server_description.all_hosts: + if address not in sds: + sds[address] = ServerDescription(address) + + if server_description.me and server_description.address != server_description.me: + sds.pop(server_description.address) + + return topology_type, replica_set_name + + +def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: + """Current topology type is ReplicaSetWithPrimary. Is primary still known? + + Pass in a dict of ServerDescriptions. + + Returns new topology type. + """ + for s in sds.values(): + if s.server_type == SERVER_TYPE.RSPrimary: + return TOPOLOGY_TYPE.ReplicaSetWithPrimary + else: # noqa: PLW0120 + return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/synchronous/typings.py b/pymongo/typings.py similarity index 58% rename from pymongo/synchronous/typings.py rename to pymongo/typings.py index bc3fb0938..9f6d7b166 100644 --- a/pymongo/synchronous/typings.py +++ b/pymongo/typings.py @@ -29,9 +29,16 @@ from typing import ( from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg if TYPE_CHECKING: - from pymongo.synchronous.collation import Collation + from pymongo.asynchronous.bulk import _AsyncBulk + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.collation import Collation + from pymongo.synchronous.bulk import _Bulk + from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.mongo_client import MongoClient + from 
pymongo.synchronous.pool import Connection -_IS_SYNC = True # Common Shared Types. _Address = Tuple[str, Optional[int]] @@ -41,9 +48,15 @@ ClusterTime = Mapping[str, Any] _T = TypeVar("_T") +# Type hinting types for compatibility between async and sync classes +_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] +_AgnosticConnection = Union["AsyncConnection", "Connection"] +_AgnosticClientSession = Union["AsyncClientSession", "ClientSession"] +_AgnosticBulk = Union["_AsyncBulk", "_Bulk"] + def strip_optional(elem: Optional[_T]) -> _T: - """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T + """This function is to allow us to cast all the elements of an iterator from Optional[_T] to _T while inside a list comprehension. """ assert elem is not None @@ -58,4 +71,5 @@ __all__ = [ "_CollationIn", "_Pipeline", "strip_optional", + "_AgnosticMongoClient", ] diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index e74ef1883..4ebd3008c 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -1,21 +1,623 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2011-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. 
See the License for the specific language governing +# permissions and limitations under the License. -"""Re-import of synchronous URIParser API for compatibility.""" + +"""Tools to parse and validate a MongoDB URI.""" from __future__ import annotations -from pymongo.synchronous.uri_parser import * # noqa: F403 -from pymongo.synchronous.uri_parser import __doc__ as original_doc +import re +import sys +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sized, + Union, + cast, +) +from urllib.parse import unquote_plus -__doc__ = original_doc +from pymongo.client_options import _parse_ssl_options +from pymongo.common import ( + INTERNAL_URI_OPTION_NAME_MAP, + SRV_SERVICE_NAME, + URI_OPTIONS_DEPRECATION_MAP, + _CaseInsensitiveDictionary, + get_validated_options, +) +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.typings import _Address + +if TYPE_CHECKING: + from pymongo.pyopenssl_context import SSLContext + +SCHEME = "mongodb://" +SCHEME_LEN = len(SCHEME) +SRV_SCHEME = "mongodb+srv://" +SRV_SCHEME_LEN = len(SRV_SCHEME) +DEFAULT_PORT = 27017 + + +def _unquoted_percent(s: str) -> bool: + """Check for unescaped percent signs. + + :param s: A string. `s` can have things like '%25', '%2525', + and '%E2%85%A8' but cannot have unquoted percent like '%foo'. + """ + for i in range(len(s)): + if s[i] == "%": + sub = s[i : i + 3] + # If unquoting yields the same string this means there was an + # unquoted %. + if unquote_plus(sub) == sub: + return True + return False + + +def parse_userinfo(userinfo: str) -> tuple[str, str]: + """Validates the format of user information in a MongoDB URI. + Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", + "]", "@") as per RFC 3986 must be escaped. + + Returns a 2-tuple containing the unescaped username followed + by the unescaped password. 
+ + :param userinfo: A string of the form : + """ + if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): + raise InvalidURI( + "Username and password must be escaped according to " + "RFC 3986, use urllib.parse.quote_plus" + ) + + user, _, passwd = userinfo.partition(":") + # No password is expected with GSSAPI authentication. + if not user: + raise InvalidURI("The empty string is not valid username.") + + return unquote_plus(user), unquote_plus(passwd) + + +def parse_ipv6_literal_host( + entity: str, default_port: Optional[int] +) -> tuple[str, Optional[Union[str, int]]]: + """Validates an IPv6 literal host:port string. + + Returns a 2-tuple of IPv6 literal followed by port where + port is default_port if it wasn't specified in entity. + + :param entity: A string that represents an IPv6 literal enclosed + in braces (e.g. '[::1]' or '[::1]:27017'). + :param default_port: The port number to use when one wasn't + specified in entity. + """ + if entity.find("]") == -1: + raise ValueError( + "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." + ) + i = entity.find("]:") + if i == -1: + return entity[1:-1], default_port + return entity[1:i], entity[i + 2 :] + + +def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: + """Validates a host string + + Returns a 2-tuple of host followed by port where port is default_port + if it wasn't specified in the string. + + :param entity: A host or host:port string where host could be a + hostname or IP address. + :param default_port: The port number to use when one wasn't + specified in entity. 
+ """ + host = entity + port: Optional[Union[str, int]] = default_port + if entity[0] == "[": + host, port = parse_ipv6_literal_host(entity, default_port) + elif entity.endswith(".sock"): + return entity, default_port + elif entity.find(":") != -1: + if entity.count(":") > 1: + raise ValueError( + "Reserved characters such as ':' must be " + "escaped according RFC 2396. An IPv6 " + "address literal must be enclosed in '[' " + "and ']' according to RFC 2732." + ) + host, port = host.split(":", 1) + if isinstance(port, str): + if not port.isdigit() or int(port) > 65535 or int(port) <= 0: + raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") + port = int(port) + + # Normalize hostname to lowercase, since DNS is case-insensitive: + # http://tools.ietf.org/html/rfc4343 + # This prevents useless rediscovery if "foo.com" is in the seed list but + # "FOO.com" is in the hello response. + return host.lower(), port + + +# Options whose values are implicitly determined by tlsInsecure. +_IMPLICIT_TLSINSECURE_OPTS = { + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", +} + + +def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: + """Helper method for split_options which creates the options dict. + Also handles the creation of a list for the URI tag_sets/ + readpreferencetags portion, and the use of a unicode options string. 
+ """ + options = _CaseInsensitiveDictionary() + for uriopt in opts.split(delim): + key, value = uriopt.split("=") + if key.lower() == "readpreferencetags": + options.setdefault(key, []).append(value) + else: + if key in options: + warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) + if key.lower() == "authmechanismproperties": + val = value + else: + val = unquote_plus(value) + options[key] = val + + return options + + +def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Raise appropriate errors when conflicting TLS options are present in + the options dictionary. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Implicitly defined options must not be explicitly specified. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + if opt in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) + ) + + # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. + tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") + if tlsallowinvalidcerts is not None: + if "tlsdisableocspendpointcheck" in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg + % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) + ) + if tlsallowinvalidcerts is True: + options["tlsdisableocspendpointcheck"] = True + + # Handle co-occurence of CRL and OCSP-related options. + tlscrlfile = options.get("tlscrlfile") + if tlscrlfile is not None: + for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): + if options.get(opt) is True: + err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." 
+ raise InvalidURI(err_msg % (opt,)) + + if "ssl" in options and "tls" in options: + + def truth_value(val: Any) -> Any: + if val in ("true", "false"): + return val == "true" + if isinstance(val, bool): + return val + return val + + if truth_value(options.get("ssl")) != truth_value(options.get("tls")): + err_msg = "Can not specify conflicting values for URI options %s and %s." + raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) + + return options + + +def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Issue appropriate warnings when deprecated options are present in the + options dictionary. Removes deprecated option key, value pairs if the + options dictionary is found to also have the renamed option. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + for optname in list(options): + if optname in URI_OPTIONS_DEPRECATION_MAP: + mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] + if mode == "renamed": + newoptname = message + if newoptname in options: + warn_msg = "Deprecated option '%s' ignored in favor of '%s'." + warnings.warn( + warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), + DeprecationWarning, + stacklevel=2, + ) + options.pop(optname) + continue + warn_msg = "Option '%s' is deprecated, use '%s' instead." + warnings.warn( + warn_msg % (options.cased_key(optname), newoptname), + DeprecationWarning, + stacklevel=2, + ) + elif mode == "removed": + warn_msg = "Option '%s' is deprecated. %s." + warnings.warn( + warn_msg % (options.cased_key(optname), message), + DeprecationWarning, + stacklevel=2, + ) + + return options + + +def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Normalizes option names in the options dictionary by converting them to + their internally-used names. 
+ + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Expand the tlsInsecure option. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + # Implicit options are logically the same as tlsInsecure. + options[opt] = tlsinsecure + + for optname in list(options): + intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) + if intname is not None: + options[intname] = options.pop(optname) + + return options + + +def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: + """Validates and normalizes options passed in a MongoDB URI. + + Returns a new dictionary of validated and normalized options. If warn is + False then errors will be thrown for invalid options, otherwise they will + be ignored and a warning will be issued. + + :param opts: A dict of MongoDB URI options. + :param warn: If ``True`` then warnings will be logged and + invalid options will be ignored. Otherwise invalid options will + cause errors. + """ + return get_validated_options(opts, warn) + + +def split_options( + opts: str, validate: bool = True, warn: bool = False, normalize: bool = True +) -> MutableMapping[str, Any]: + """Takes the options portion of a MongoDB URI, validates each option + and returns the options in a dictionary. + + :param opt: A string representing MongoDB URI options. + :param validate: If ``True`` (the default), validate and normalize all + options. + :param warn: If ``False`` (the default), suppress all warnings raised + during validation of options. + :param normalize: If ``True`` (the default), renames all options to their + internally-used names. 
+ """ + and_idx = opts.find("&") + semi_idx = opts.find(";") + try: + if and_idx >= 0 and semi_idx >= 0: + raise InvalidURI("Can not mix '&' and ';' for option separators.") + elif and_idx >= 0: + options = _parse_options(opts, "&") + elif semi_idx >= 0: + options = _parse_options(opts, ";") + elif opts.find("=") != -1: + options = _parse_options(opts, None) + else: + raise ValueError + except ValueError: + raise InvalidURI("MongoDB URI options are key=value pairs.") from None + + options = _handle_security_options(options) + + options = _handle_option_deprecations(options) + + if normalize: + options = _normalize_options(options) + + if validate: + options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) + if options.get("authsource") == "": + raise InvalidURI("the authSource database cannot be an empty string") + + return options + + +def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: + """Takes a string of the form host1[:port],host2[:port]... and + splits it into (host, port) tuples. If [:port] isn't present the + default_port is used. + + Returns a set of 2-tuples containing the host name (or IP) followed by + port number. + + :param hosts: A string of the form host1[:port],host2[:port],... + :param default_port: The port number to use when one wasn't specified + for a host. + """ + nodes = [] + for entity in hosts.split(","): + if not entity: + raise ConfigurationError("Empty host (or extra comma in host list).") + port = default_port + # Unix socket entities don't have ports + if entity.endswith(".sock"): + port = None + nodes.append(parse_host(entity, port)) + return nodes + + +# Prohibited characters in database name. DB names also can't have ".", but for +# backward-compat we allow "db.collection" in URI. 
+_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") + +_ALLOWED_TXT_OPTS = frozenset( + ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] +) + + +def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: + # Ensure directConnection was not True if there are multiple seeds. + if len(nodes) > 1 and options.get("directconnection"): + raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") + + if options.get("loadbalanced"): + if len(nodes) > 1: + raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") + if options.get("directconnection"): + raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") + if options.get("replicaset"): + raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") + + +def parse_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': , + 'fqdn': or None + } + + If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done + to build nodelist and options. + + :param uri: The MongoDB URI to parse. + :param default_port: The port number to use when one wasn't specified + for a host in the URI. + :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. 
+ :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. " + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP.") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." 
in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. 
+ connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = dns_resolver.get_hosts() + dns_options = dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } + + +def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: + """Parse KMS TLS connection options.""" + if not kms_tls_options: + return {} + if not isinstance(kms_tls_options, dict): + raise TypeError("kms_tls_options must be a dict") + contexts = {} + for provider, options in kms_tls_options.items(): + if not isinstance(options, dict): + raise TypeError(f'kms_tls_options["{provider}"] must be a dict') + options.setdefault("tls", True) + opts = 
_CaseInsensitiveDictionary(options) + opts = _handle_security_options(opts) + opts = _normalize_options(opts) + opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) + ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) + if ssl_context is None: + raise ConfigurationError("TLS is required for KMS providers") + if allow_invalid_hostnames: + raise ConfigurationError("Insecure TLS options prohibited") + + for n in [ + "tlsInsecure", + "tlsAllowInvalidCertificates", + "tlsAllowInvalidHostnames", + "tlsDisableCertificateRevocationCheck", + ]: + if n in opts: + raise ConfigurationError(f"Insecure TLS options prohibited: {n}") + contexts[provider] = ssl_context + return contexts + + +if __name__ == "__main__": + import pprint + + try: + pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 + except InvalidURI as exc: + print(exc) # noqa: T201 + sys.exit(0) diff --git a/requirements/encryption.txt b/requirements/encryption.txt index bae6115eb..1a8c14844 100644 --- a/requirements/encryption.txt +++ b/requirements/encryption.txt @@ -1,3 +1,3 @@ pymongo-auth-aws>=1.1.0,<2.0.0 -pymongocrypt>=1.6.0,<2.0.0 +pymongocrypt>=1.10.0,<2.0.0 certifi;os.name=='nt' or sys_platform=='darwin' diff --git a/setup.py b/setup.py index 9312c68bd..f371b3d75 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,8 @@ from __future__ import annotations -msg = "PyMongo>=4.8 no longer supports building via setup.py, use python -m pip install instead" +msg = ( + "PyMongo>=4.8 no longer supports building via setup.py, use python -m pip install instead. 
If " + "this is an editable install (-e) please upgrade to pip>=21.3 first: python -m pip install --upgrade pip" +) raise RuntimeError(msg) diff --git a/test/__init__.py b/test/__init__.py index f89363a33..19218f01a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -45,14 +45,14 @@ from urllib.parse import quote_plus import pymongo import pymongo.errors from bson.son import SON +from pymongo import common, message +from pymongo.common import partition_node +from pymongo.hello import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous import common, message -from pymongo.synchronous.common import partition_node from pymongo.synchronous.database import Database -from pymongo.synchronous.hello_compat import HelloCompat from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d38065eb3..0a74366ae 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -63,25 +63,20 @@ except ImportError: from contextlib import asynccontextmanager, contextmanager from functools import wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator from unittest import SkipTest from urllib.parse import quote_plus import pymongo import pymongo.errors from bson.son import SON -from pymongo.asynchronous import common, message -from pymongo.asynchronous.common import partition_node from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.hello_compat import HelloCompat from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.common import partition_node +from pymongo.hello 
import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -if HAVE_SSL: - import ssl - _IS_SYNC = False diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 078bad9e2..9e40f53a4 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -28,7 +28,10 @@ from pymongo.asynchronous.database import AsyncDatabase sys.path[0:0] = [""] from test import unittest -from test.asynchronous import AsyncIntegrationTest, async_client_context +from test.asynchronous import ( # TODO: fix sync imports in PYTHON-4528 + AsyncIntegrationTest, + async_client_context, +) from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, @@ -47,14 +50,11 @@ from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT -from pymongo.asynchronous.bulk import BulkWriteError from pymongo.asynchronous.collection import AsyncCollection, ReturnDocument from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.helpers import anext -from pymongo.asynchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.operations import * -from pymongo.asynchronous.read_preferences import ReadPreference +from pymongo.bulk_shared import BulkWriteError from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, @@ -67,7 +67,10 @@ from pymongo.errors import ( OperationFailure, WriteConcernError, ) +from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command +from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference from pymongo.results import ( DeleteResult, InsertManyResult, @@ -472,8 +475,8 @@ class 
AsyncTestCollection(AsyncIntegrationTest): async def test_index_haystack(self): db = self.db await db.test.drop() - _id = await db.test.insert_one( - {"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"} + _id = ( + await db.test.insert_one({"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"}) ).inserted_id await db.test.insert_one({"pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant"}) await db.test.insert_one({"pos": {"long": 59.1, "lat": 87.2}, "type": "office"}) @@ -1642,7 +1645,7 @@ class AsyncTestCollection(AsyncIntegrationTest): with await self.db.test.aggregate([], {}): # type:ignore pass - with self.assertRaisesRegex(ValueError, "must be a ClientSession"): + with self.assertRaisesRegex(ValueError, "must be an AsyncClientSession"): await try_invalid_session() async def test_large_limit(self): diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index f6a6c9694..3e5dcec56 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -26,7 +26,7 @@ from pymongo_auth_aws import AwsCredential, auth from pymongo import MongoClient from pymongo.errors import OperationFailure -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri class TestAuthAWS(unittest.TestCase): diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 3fb289478..83e25e685 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -36,14 +36,14 @@ from pymongo._azure_helpers import _get_azure_response from pymongo._gcp_helpers import _get_gcp_response from pymongo.cursor_shared import CursorType from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure +from pymongo.hello import HelloCompat +from pymongo.operations import InsertOne from pymongo.synchronous.auth_oidc import ( OIDCCallback, OIDCCallbackContext, OIDCCallbackResult, ) -from pymongo.synchronous.hello_compat import HelloCompat -from 
pymongo.synchronous.operations import InsertOne -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() TEST_PATH = ROOT / "auth" / "unified" diff --git a/test/lambda/mongodb/app.py b/test/lambda/mongodb/app.py index deb26bdf1..5840347d9 100644 --- a/test/lambda/mongodb/app.py +++ b/test/lambda/mongodb/app.py @@ -12,7 +12,7 @@ import os from bson import has_c as has_bson_c from pymongo import MongoClient from pymongo import has_c as has_pymongo_c -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( CommandListener, ConnectionPoolListener, ServerHeartbeatListener, diff --git a/test/mockupdb/test_mongos_command_read_mode.py b/test/mockupdb/test_mongos_command_read_mode.py index 1e91384dc..8ee33431a 100644 --- a/test/mockupdb/test_mongos_command_read_mode.py +++ b/test/mockupdb/test_mongos_command_read_mode.py @@ -20,7 +20,7 @@ from mockupdb import MockupDB, OpMsg, going from operations import operations # type: ignore[import] from pymongo import MongoClient, ReadPreference -from pymongo.synchronous.read_preferences import ( +from pymongo.read_preferences import ( _MONGOS_MODES, make_read_preference, read_pref_mode_from_name, diff --git a/test/mockupdb/test_network_disconnect_primary.py b/test/mockupdb/test_network_disconnect_primary.py index 36e004c05..d05cfb531 100644 --- a/test/mockupdb/test_network_disconnect_primary.py +++ b/test/mockupdb/test_network_disconnect_primary.py @@ -19,7 +19,7 @@ from mockupdb import Future, MockupDB, OpReply, going, wait_until from pymongo import MongoClient from pymongo.errors import ConnectionFailure -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE class TestNetworkDisconnectPrimary(unittest.TestCase): diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py index aa2437f23..d36e5e02b 100644 --- a/test/mockupdb/test_op_msg.py +++ 
b/test/mockupdb/test_op_msg.py @@ -20,7 +20,7 @@ from mockupdb import OP_MSG_FLAGS, MockupDB, OpMsg, OpMsgReply, going from pymongo import MongoClient, WriteConcern from pymongo.cursor_shared import CursorType -from pymongo.synchronous.operations import DeleteOne, InsertOne, UpdateOne +from pymongo.operations import DeleteOne, InsertOne, UpdateOne Operation = namedtuple("Operation", ["name", "function", "request", "reply"]) diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py index 36b8f4fbe..0fa7b8486 100644 --- a/test/mockupdb/test_op_msg_read_preference.py +++ b/test/mockupdb/test_op_msg_read_preference.py @@ -22,7 +22,7 @@ from mockupdb import CommandBase, MockupDB, going from operations import operations # type: ignore[import] from pymongo import MongoClient, ReadPreference -from pymongo.synchronous.read_preferences import ( +from pymongo.read_preferences import ( _MONGOS_MODES, make_read_preference, read_pref_mode_from_name, diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py index 9eb4de28c..529770988 100644 --- a/test/mockupdb/test_query_read_pref_sharded.py +++ b/test/mockupdb/test_query_read_pref_sharded.py @@ -21,7 +21,7 @@ from mockupdb import MockupDB, OpMsg, going from bson import SON from pymongo import MongoClient -from pymongo.synchronous.read_preferences import ( +from pymongo.read_preferences import ( Nearest, Primary, PrimaryPreferred, diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py index 080110020..19dfb9e39 100644 --- a/test/mockupdb/test_reset_and_request_check.py +++ b/test/mockupdb/test_reset_and_request_check.py @@ -22,8 +22,8 @@ from operations import operations # type: ignore[import] from pymongo import MongoClient from pymongo.errors import ConnectionFailure +from pymongo.operations import _Op from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.operations 
import _Op class TestResetAndRequestCheck(unittest.TestCase): diff --git a/test/mockupdb/test_slave_okay_sharded.py b/test/mockupdb/test_slave_okay_sharded.py index 9692465d5..45b7d51ba 100644 --- a/test/mockupdb/test_slave_okay_sharded.py +++ b/test/mockupdb/test_slave_okay_sharded.py @@ -28,7 +28,7 @@ from mockupdb import MockupDB, going from operations import operations # type: ignore[import] from pymongo import MongoClient -from pymongo.synchronous.read_preferences import make_read_preference, read_pref_mode_from_name +from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name class TestSlaveOkaySharded(unittest.TestCase): diff --git a/test/mockupdb/test_slave_okay_single.py b/test/mockupdb/test_slave_okay_single.py index bf1cdee74..b03232807 100644 --- a/test/mockupdb/test_slave_okay_single.py +++ b/test/mockupdb/test_slave_okay_single.py @@ -27,8 +27,8 @@ from mockupdb import MockupDB, going from operations import operations # type: ignore[import] from pymongo import MongoClient -from pymongo.synchronous.read_preferences import make_read_preference, read_pref_mode_from_name -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE +from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name +from pymongo.topology_description import TOPOLOGY_TYPE def topology_type_name(client): diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index d3c1a271c..9cbca169c 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -20,13 +20,13 @@ import weakref from functools import partial from test import client_context +from pymongo import common from pymongo.errors import AutoReconnect, NetworkTimeout -from pymongo.synchronous import common -from pymongo.synchronous.hello import Hello, HelloCompat +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.monitor import 
Monitor from pymongo.synchronous.pool import Pool -from pymongo.synchronous.server_description import ServerDescription class MockPool(Pool): diff --git a/test/sigstop_sigcont.py b/test/sigstop_sigcont.py index c5084f594..bc1bacce3 100644 --- a/test/sigstop_sigcont.py +++ b/test/sigstop_sigcont.py @@ -21,8 +21,8 @@ import sys sys.path[0:0] = [""] +from pymongo import monitoring from pymongo.server_api import ServerApi -from pymongo.synchronous import monitoring from pymongo.synchronous.mongo_client import MongoClient SERVER_API = None diff --git a/test/synchronous/__init__.py b/test/synchronous/__init__.py index 6eb11eee8..9176b22d1 100644 --- a/test/synchronous/__init__.py +++ b/test/synchronous/__init__.py @@ -63,24 +63,19 @@ except ImportError: from contextlib import contextmanager from functools import wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator from unittest import SkipTest from urllib.parse import quote_plus import pymongo import pymongo.errors from bson.son import SON +from pymongo.common import partition_node +from pymongo.hello import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous import common, message -from pymongo.synchronous.common import partition_node from pymongo.synchronous.database import Database -from pymongo.synchronous.hello_compat import HelloCompat from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.uri_parser import parse_uri - -if HAVE_SSL: - import ssl _IS_SYNC = True diff --git a/test/synchronous/conftest.py b/test/synchronous/conftest.py index 5befb96e1..58f04ea7c 100644 --- a/test/synchronous/conftest.py +++ b/test/synchronous/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from test.synchronous import setup, teardown +from test import setup, teardown import pytest diff 
--git a/test/synchronous/test_collection.py b/test/synchronous/test_collection.py index 39d7e13a3..7d105acb6 100644 --- a/test/synchronous/test_collection.py +++ b/test/synchronous/test_collection.py @@ -27,8 +27,11 @@ from pymongo.synchronous.database import Database sys.path[0:0] = [""] -from test import unittest -from test.synchronous import IntegrationTest, client_context +from test import ( # TODO: fix sync imports in PYTHON-4528 + IntegrationTest, + client_context, + unittest, +) from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, @@ -46,6 +49,7 @@ from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT +from pymongo.bulk_shared import BulkWriteError from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, @@ -58,21 +62,20 @@ from pymongo.errors import ( OperationFailure, WriteConcernError, ) +from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command +from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference from pymongo.results import ( DeleteResult, InsertManyResult, InsertOneResult, UpdateResult, ) -from pymongo.synchronous.bulk import BulkWriteError from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.helpers import next -from pymongo.synchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import * -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -134,7 +137,7 @@ class TestCollectionNoConnect(unittest.TestCase): if _IS_SYNC: msg = "'Collection' object is not iterable" else: - msg = "'AsyncCollection' object is not iterable" 
+ msg = "'Collection' object is not iterable" # Iteration fails with self.assertRaisesRegex(TypeError, msg): for _ in coll: # type: ignore[misc] # error: "None" not callable [misc] @@ -461,8 +464,8 @@ class TestCollection(IntegrationTest): def test_index_haystack(self): db = self.db db.test.drop() - _id = db.test.insert_one( - {"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"} + _id = ( + db.test.insert_one({"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"}) ).inserted_id db.test.insert_one({"pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant"}) db.test.insert_one({"pos": {"long": 59.1, "lat": 87.2}, "type": "office"}) diff --git a/test/test_auth.py b/test/test_auth.py index bf2e5f6f8..2ae0eae12 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -33,13 +33,13 @@ from test.utils import ( single_client_noauth, ) -from pymongo import MongoClient -from pymongo.asynchronous.auth import HAVE_KERBEROS, _build_credentials_tuple +from pymongo import MongoClient, monitoring +from pymongo.asynchronous.auth import HAVE_KERBEROS +from pymongo.auth_shared import _build_credentials_tuple from pymongo.errors import OperationFailure +from pymongo.hello import HelloCompat +from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP -from pymongo.synchronous import monitoring -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.read_preferences import ReadPreference # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. 
GSSAPI_HOST = os.environ.get("GSSAPI_HOST") diff --git a/test/test_binary.py b/test/test_binary.py index 66a57dcb5..93f6d0831 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -33,7 +33,7 @@ from bson import decode, encode from bson.binary import * from bson.codec_options import CodecOptions from bson.son import SON -from pymongo.synchronous.common import validate_uuid_representation +from pymongo.common import validate_uuid_representation from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern diff --git a/test/test_bulk.py b/test/test_bulk.py index 42dbf5b15..c0f859443 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -34,15 +34,15 @@ from test.utils import ( from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.objectid import ObjectId +from pymongo.common import partition_node from pymongo.errors import ( BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure, ) +from pymongo.operations import * from pymongo.synchronous.collection import Collection -from pymongo.synchronous.common import partition_node -from pymongo.synchronous.operations import * from pymongo.write_concern import WriteConcern diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 4d8422667..e00aaa640 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -45,9 +45,9 @@ from pymongo.errors import ( OperationFailure, ServerSelectionTimeoutError, ) +from pymongo.message import _CursorAddress from pymongo.read_concern import ReadConcern from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.message import _CursorAddress from pymongo.write_concern import WriteConcern diff --git a/test/test_client.py b/test/test_client.py index 503c2e6e3..d7f1e0047 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -36,7 +36,7 @@ from unittest.mock import patch import pytest -from 
pymongo.synchronous.operations import _Op +from pymongo.operations import _Op sys.path[0:0] = [""] @@ -83,6 +83,10 @@ from bson.codec_options import ( ) from bson.son import SON from bson.tz_util import utc +from pymongo import event_loggers, message, monitoring +from pymongo.client_options import ClientOptions +from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT +from pymongo.compression_support import _have_snappy, _have_zstd from pymongo.driver_info import DriverInfo from pymongo.errors import ( AutoReconnect, @@ -96,29 +100,22 @@ from pymongo.errors import ( ServerSelectionTimeoutError, WriteConcernError, ) +from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent +from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions +from pymongo.read_preferences import ReadPreference +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import readable_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import event_loggers, message, monitoring -from pymongo.synchronous.client_options import ClientOptions from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT -from pymongo.synchronous.compression_support import _have_snappy, _have_zstd from pymongo.synchronous.cursor import Cursor, CursorType from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent from pymongo.synchronous.pool import ( - _MAX_METADATA_SIZE, - _METADATA, - ENV_VAR_K8S, Connection, - PoolOptions, ) -from pymongo.synchronous.read_preferences import ReadPreference -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import readable_server_selector, 
writable_server_selector from pymongo.synchronous.settings import TOPOLOGY_TYPE from pymongo.synchronous.topology import _ErrorContext -from pymongo.synchronous.topology_description import TopologyDescription +from pymongo.topology_description import TopologyDescription from pymongo.write_concern import WriteConcern @@ -482,13 +479,13 @@ class ClientUnitTest(unittest.TestCase): def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. - from pymongo.synchronous.srv_resolver import _resolve + from pymongo.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.synchronous.srv_resolver._resolve = patched_resolver + pymongo.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.synchronous.srv_resolver._resolve = _resolve + pymongo.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) @@ -577,7 +574,7 @@ class ClientUnitTest(unittest.TestCase): with self.assertRaisesRegex(ConfigurationError, expected): MongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -599,7 +596,7 @@ class ClientUnitTest(unittest.TestCase): logs = [record.message for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ diff --git a/test/test_collation.py b/test/test_collation.py index f4830da5d..bedf0a2ea 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -21,15 +21,15 @@ from test import IntegrationTest, client_context, unittest from test.utils import EventListener, rs_or_single_client from 
typing import Any -from pymongo.errors import ConfigurationError -from pymongo.synchronous.collation import ( +from pymongo.collation import ( Collation, CollationAlternate, CollationCaseFirst, CollationMaxVariable, CollationStrength, ) -from pymongo.synchronous.operations import ( +from pymongo.errors import ConfigurationError +from pymongo.operations import ( DeleteMany, DeleteOne, IndexModel, diff --git a/test/test_collection.py b/test/test_collection.py index 4bbe0fb57..0de506e0f 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -45,6 +45,7 @@ from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT +from pymongo.bulk_shared import BulkWriteError from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, @@ -57,20 +58,19 @@ from pymongo.errors import ( OperationFailure, WriteConcernError, ) +from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command +from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference from pymongo.results import ( DeleteResult, InsertManyResult, InsertOneResult, UpdateResult, ) -from pymongo.synchronous.bulk import BulkWriteError from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import * -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern diff --git a/test/test_comment.py b/test/test_comment.py index f9630655c..931446ef3 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -25,8 +25,8 @@ from test import IntegrationTest, client_context, unittest from test.utils import 
EventListener, rs_or_single_client from bson.dbref import DBRef +from pymongo.operations import IndexModel from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.operations import IndexModel class Empty: diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 4ddbd07bb..9ee3202e1 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -45,7 +45,7 @@ from pymongo.errors import ( PyMongoError, WaitQueueTimeoutError, ) -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( ConnectionCheckedInEvent, ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, @@ -60,9 +60,9 @@ from pymongo.synchronous.monitoring import ( PoolCreatedEvent, PoolReadyEvent, ) +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.pool import PoolState, _PoolClosedError -from pymongo.synchronous.read_preferences import ReadPreference -from pymongo.synchronous.topology_description import updated_topology_description +from pymongo.topology_description import updated_topology_description OBJECT_TYPES = { # Event types. 
diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index bb80bda93..674612693 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -28,8 +28,8 @@ from test.utils import ( ) from bson import SON +from pymongo import monitoring from pymongo.errors import NotPrimaryError -from pymongo.synchronous import monitoring from pymongo.synchronous.collection import Collection from pymongo.write_concern import WriteConcern diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py index b13e4c844..d528a1dfe 100644 --- a/test/test_crud_v1.py +++ b/test/test_crud_v1.py @@ -29,14 +29,9 @@ from test.utils import ( drop_collections, ) -from pymongo import WriteConcern +from pymongo import WriteConcern, operations from pymongo.errors import PyMongoError -from pymongo.read_concern import ReadConcern -from pymongo.results import BulkWriteResult, _WriteResult -from pymongo.synchronous import operations -from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.cursor import Cursor -from pymongo.synchronous.operations import ( +from pymongo.operations import ( DeleteMany, DeleteOne, InsertOne, @@ -44,6 +39,10 @@ from pymongo.synchronous.operations import ( UpdateMany, UpdateOne, ) +from pymongo.read_concern import ReadConcern +from pymongo.results import BulkWriteResult, _WriteResult +from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.cursor import Cursor # Location of JSON test specifications. 
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "v1") diff --git a/test/test_cursor.py b/test/test_cursor.py index c354c42b3..8a6f1f404 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -41,14 +41,13 @@ from test.utils import ( from bson import decode_all from bson.code import Code -from bson.son import SON from pymongo import ASCENDING, DESCENDING +from pymongo.collation import Collation from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure +from pymongo.operations import _IndexList from pymongo.read_concern import ReadConcern -from pymongo.synchronous.collation import Collation +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.cursor import Cursor, CursorType -from pymongo.synchronous.operations import _IndexList -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern diff --git a/test/test_custom_types.py b/test/test_custom_types.py index d946eee17..7daf83244 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -53,8 +53,8 @@ from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument from gridfs import GridIn, GridOut from pymongo.errors import DuplicateKeyError +from pymongo.message import _CursorAddress from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.message import _CursorAddress class DecimalEncoder(TypeEncoder): diff --git a/test/test_database.py b/test/test_database.py index 82c4a086e..b56a053f0 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -38,6 +38,7 @@ from bson.int64 import Int64 from bson.objectid import ObjectId from bson.regex import Regex from bson.son import SON +from pymongo import helpers_shared from pymongo.asynchronous import auth from pymongo.errors import ( CollectionInvalid, @@ -48,11 +49,10 @@ from pymongo.errors import ( WriteConcernError, ) from pymongo.read_concern import ReadConcern -from 
pymongo.synchronous import helpers +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.collection import Collection from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern @@ -578,10 +578,10 @@ class TestDatabase(IntegrationTest): # Sometimes (SERVER-10891) the server's response to a badly-formatted # command document will have no 'ok' field. We should raise # OperationFailure instead of KeyError. - self.assertRaises(OperationFailure, helpers._check_command_response, {}, None) + self.assertRaises(OperationFailure, helpers_shared._check_command_response, {}, None) try: - helpers._check_command_response({"$err": "foo"}, None) + helpers_shared._check_command_response({"$err": "foo"}, None) except OperationFailure as e: self.assertEqual(e.args[0], "foo, full error: {'$err': 'foo'}") else: @@ -595,7 +595,7 @@ class TestDatabase(IntegrationTest): } with self.assertRaises(OperationFailure) as context: - helpers._check_command_response(error_document, None) + helpers_shared._check_command_response(error_document, None) self.assertIn("inner", str(context.exception)) @@ -605,7 +605,7 @@ class TestDatabase(IntegrationTest): error_document = {"ok": 0, "errmsg": "outer", "raw": {"shard0/host0,host1": {}}} with self.assertRaises(OperationFailure) as context: - helpers._check_command_response(error_document, None) + helpers_shared._check_command_response(error_document, None) self.assertIn("outer", str(context.exception)) @@ -613,7 +613,7 @@ class TestDatabase(IntegrationTest): error_document = {"ok": 0, "errmsg": "outer", "raw": {"shard0/host0,host1": {"ok": 0}}} with self.assertRaises(OperationFailure) as context: - helpers._check_command_response(error_document, None) + helpers_shared._check_command_response(error_document, None) self.assertIn("outer", str(context.exception)) diff --git 
a/test/test_default_exports.py b/test/test_default_exports.py index 91f94c9db..d9301d222 100644 --- a/test/test_default_exports.py +++ b/test/test_default_exports.py @@ -26,7 +26,7 @@ BSON_IGNORE = [] GRIDFS_IGNORE = [ "ASCENDING", "DESCENDING", - "ClientSession", + "AsyncClientSession", "Collection", "ObjectId", "validate_string", diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 53602eaec..ef32afbcd 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -40,7 +40,7 @@ from test.utils import ( from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import MongoClient +from pymongo import MongoClient, common, monitoring from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -48,15 +48,14 @@ from pymongo.errors import ( NotPrimaryError, OperationFailure, ) -from pymongo.synchronous import common, monitoring -from pymongo.synchronous.hello import Hello, HelloCompat -from pymongo.synchronous.helpers import _check_command_response, _check_write_command_response -from pymongo.synchronous.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent -from pymongo.synchronous.server_description import SERVER_TYPE, ServerDescription +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers_shared import _check_command_response, _check_write_command_response +from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.topology_description import TOPOLOGY_TYPE +from pymongo.uri_parser import parse_uri # Location of JSON test specifications. 
SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") diff --git a/test/test_dns.py b/test/test_dns.py index a2d0fd8b4..b4c5e3684 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -25,10 +25,10 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest from test.utils import wait_until +from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.synchronous.common import validate_read_preference_tags from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.uri_parser import parse_uri, split_hosts +from pymongo.uri_parser import parse_uri, split_hosts class TestDNSRepl(unittest.TestCase): diff --git a/test/test_encryption.py b/test/test_encryption.py index 585847b10..9306876d1 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -70,6 +70,7 @@ from bson.json_util import JSONOptions from bson.son import SON from pymongo import ReadPreference from pymongo.cursor_shared import CursorType +from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts, RangeOpts from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -82,11 +83,10 @@ from pymongo.errors import ( ServerSelectionTimeoutError, WriteError, ) +from pymongo.operations import InsertOne, ReplaceOne, UpdateOne from pymongo.synchronous import encryption from pymongo.synchronous.encryption import Algorithm, ClientEncryption, QueryType -from pymongo.synchronous.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts, RangeOpts from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import InsertOne, ReplaceOne, UpdateOne from pymongo.write_concern import WriteConcern KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}} diff --git a/test/test_examples.py b/test/test_examples.py index f0d8bd554..e003d8459 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -27,8 +27,8 
@@ from test.utils import rs_client, wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference from pymongo.server_api import ServerApi -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern diff --git a/test/test_grid_file.py b/test/test_grid_file.py index c45c5b577..f663f1365 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -42,7 +42,7 @@ from gridfs.synchronous.grid_file import ( ) from pymongo import MongoClient from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError -from pymongo.synchronous.message import _CursorAddress +from pymongo.message import _CursorAddress class TestGridFileNoConnect(unittest.TestCase): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index f15f10804..27b38dc0b 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -37,9 +37,9 @@ from pymongo.errors import ( NotPrimaryError, ServerSelectionTimeoutError, ) +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.read_preferences import ReadPreference class JustWrite(threading.Thread): diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 6ce7b7922..c3945d105 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -41,8 +41,8 @@ from pymongo.errors import ( ServerSelectionTimeoutError, WriteConcernError, ) +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.read_preferences import ReadPreference class JustWrite(threading.Thread): diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 0566fffe5..1302df8fd 100644 --- a/test/test_heartbeat_monitoring.py +++ 
b/test/test_heartbeat_monitoring.py @@ -23,7 +23,7 @@ from test import IntegrationTest, client_knobs, unittest from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until from pymongo.errors import ConnectionFailure -from pymongo.synchronous.hello import Hello, HelloCompat +from pymongo.hello import Hello, HelloCompat from pymongo.synchronous.monitor import Monitor diff --git a/test/test_index_management.py b/test/test_index_management.py index b8409178d..5b6653dcb 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -29,8 +29,8 @@ from test.utils import AllowListEventListener, EventListener from pymongo import MongoClient from pymongo.errors import OperationFailure +from pymongo.operations import SearchIndexModel from pymongo.read_concern import ReadConcern -from pymongo.synchronous.operations import SearchIndexModel from pymongo.write_concern import WriteConcern _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management") diff --git a/test/test_logger.py b/test/test_logger.py index 60abcadc4..1dfa0724e 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -21,7 +21,7 @@ from unittest.mock import patch from bson import json_util from pymongo.errors import OperationFailure -from pymongo.synchronous.logger import _DEFAULT_DOCUMENT_LENGTH +from pymongo.logger import _DEFAULT_DOCUMENT_LENGTH # https://github.com/mongodb/specifications/tree/master/source/command-logging-and-monitoring/tests#prose-tests diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index d41f216eb..1b0130f7d 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -20,7 +20,7 @@ import sys import time import warnings -from pymongo.synchronous.operations import _Op +from pymongo.operations import _Op sys.path[0:0] = [""] @@ -30,7 +30,7 @@ from test.utils_selection_tests import create_selection_tests from pymongo import MongoClient from pymongo.errors import 
ConfigurationError -from pymongo.synchronous.server_selectors import writable_server_selector +from pymongo.server_selectors import writable_server_selector # Location of JSON test specifications. _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness") diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 812a9072f..f39a1cb03 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -18,7 +18,7 @@ from __future__ import annotations import sys import threading -from pymongo.synchronous.operations import _Op +from pymongo.operations import _Op sys.path[0:0] = [""] @@ -27,8 +27,8 @@ from test.pymongo_mocks import MockClient from test.utils import connected, wait_until from pymongo.errors import AutoReconnect, InvalidOperation -from pymongo.synchronous.server_selectors import writable_server_selector -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE +from pymongo.server_selectors import writable_server_selector +from pymongo.topology_description import TOPOLOGY_TYPE @client_context.require_connection diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 7f8888815..ed6a3d0bc 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -27,11 +27,10 @@ from test.utils import EventListener, rs_or_single_client, single_client, wait_u from bson.int64 import Int64 from bson.objectid import ObjectId from bson.son import SON -from pymongo import CursorType, DeleteOne, InsertOne, UpdateOne +from pymongo import CursorType, DeleteOne, InsertOne, UpdateOne, monitoring from pymongo.errors import AutoReconnect, NotPrimaryError, OperationFailure -from pymongo.synchronous import monitoring +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern diff --git 
a/test/test_pooling.py b/test/test_pooling.py index 5ed701517..aa32f9f77 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -24,10 +24,9 @@ import time from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON -from pymongo import MongoClient, timeout +from pymongo import MongoClient, message, timeout from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError -from pymongo.synchronous import message -from pymongo.synchronous.hello_compat import HelloCompat +from pymongo.hello import HelloCompat sys.path[0:0] = [""] diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index da550e7db..4df791e94 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -22,7 +22,7 @@ import random import sys from typing import Any -from pymongo.synchronous.operations import _Op +from pymongo.operations import _Op sys.path[0:0] = [""] @@ -39,10 +39,8 @@ from test.version import Version from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.message import _maybe_add_read_preference -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.read_preferences import ( +from pymongo.message import _maybe_add_read_preference +from pymongo.read_preferences import ( MovingAverage, Nearest, Primary, @@ -51,8 +49,10 @@ from pymongo.synchronous.read_preferences import ( Secondary, SecondaryPreferred, ) -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import Selection, readable_server_selector +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection, readable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern diff --git 
a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 93986d824..34aa1f754 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -40,9 +40,9 @@ from pymongo.errors import ( WriteError, WTimeoutError, ) +from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import IndexModel, InsertOne from pymongo.write_concern import WriteConcern _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern") diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 569f7c275..9ea546ba9 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -43,13 +43,13 @@ from test.utils import ( ) from test.utils_spec_runner import SpecRunner -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, ConnectionCheckOutFailedReason, PoolClearedEvent, ) +from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern # Location of JSON test specifications. 
diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 347e6c138..45a740e84 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -47,15 +47,14 @@ from pymongo.errors import ( ServerSelectionTimeoutError, WriteConcernError, ) -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( CommandSucceededEvent, ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, ConnectionCheckOutFailedReason, PoolClearedEvent, ) -from pymongo.synchronous.operations import ( +from pymongo.operations import ( DeleteMany, DeleteOne, InsertOne, @@ -63,6 +62,7 @@ from pymongo.synchronous.operations import ( UpdateMany, UpdateOne, ) +from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern # Location of JSON test specifications. diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index c955dc408..5faee9b10 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -31,15 +31,14 @@ from test.utils import ( ) from bson.json_util import object_hook -from pymongo import MongoClient +from pymongo import MongoClient, monitoring +from pymongo.common import clean_node from pymongo.errors import ConnectionFailure, NotPrimaryError -from pymongo.synchronous import monitoring +from pymongo.hello import Hello +from pymongo.server_description import ServerDescription from pymongo.synchronous.collection import Collection -from pymongo.synchronous.common import clean_node -from pymongo.synchronous.hello import Hello from pymongo.synchronous.monitor import Monitor -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE # Location of JSON test specifications. 
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sdam_monitoring") diff --git a/test/test_server.py b/test/test_server.py index b5c6c1365..45d01c10d 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -21,9 +21,9 @@ sys.path[0:0] = [""] from test import unittest -from pymongo.synchronous.hello import Hello +from pymongo.hello import Hello +from pymongo.server_description import ServerDescription from pymongo.synchronous.server import Server -from pymongo.synchronous.server_description import ServerDescription class TestServer(unittest.TestCase): diff --git a/test/test_server_description.py b/test/test_server_description.py index 273c001c9..ee05e95cf 100644 --- a/test/test_server_description.py +++ b/test/test_server_description.py @@ -23,9 +23,9 @@ from test import unittest from bson.int64 import Int64 from bson.objectid import ObjectId +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.hello import Hello, HelloCompat -from pymongo.synchronous.server_description import ServerDescription address = ("localhost", 27017) diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 94289a00a..d3526617f 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -20,12 +20,12 @@ import sys from pymongo import MongoClient, ReadPreference from pymongo.errors import ServerSelectionTimeoutError -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.server_selectors import writable_server_selector +from pymongo.hello import HelloCompat +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology -from pymongo.synchronous.typings import strip_optional +from 
pymongo.typings import strip_optional sys.path[0:0] = [""] diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index c7384590d..9dced595c 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -27,9 +27,9 @@ from test.utils import ( ) from test.utils_selection_tests import create_topology -from pymongo.synchronous.common import clean_node -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference +from pymongo.common import clean_node +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference # Location of JSON test specifications. TEST_PATH = os.path.join( diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py index 26e871c40..a129af458 100644 --- a/test/test_server_selection_rtt.py +++ b/test/test_server_selection_rtt.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] from test import unittest -from pymongo.synchronous.read_preferences import MovingAverage +from pymongo.read_preferences import MovingAverage # Location of JSON test specifications. 
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt") diff --git a/test/test_session.py b/test/test_session.py index 7098cec6f..fe2544b26 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -35,14 +35,13 @@ from test.utils import ( from bson import DBRef from gridfs import GridFS, GridFSBucket -from pymongo import ASCENDING +from pymongo import ASCENDING, monitoring +from pymongo.common import _MAX_END_SESSIONS from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure +from pymongo.operations import IndexModel, InsertOne, UpdateOne from pymongo.read_concern import ReadConcern -from pymongo.synchronous import monitoring from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.common import _MAX_END_SESSIONS from pymongo.synchronous.cursor import Cursor -from pymongo.synchronous.operations import IndexModel, InsertOne, UpdateOne # Ignore auth commands like saslStart, so we can assert lsid is in all commands. 
diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 0c293874b..405db14ac 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -25,10 +25,10 @@ from test import client_knobs, unittest from test.utils import FunctionCallRecorder, wait_until import pymongo +from pymongo import common from pymongo.errors import ConfigurationError -from pymongo.synchronous import common +from pymongo.srv_resolver import _have_dnspython from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.srv_resolver import _have_dnspython WAIT_TIME = 0.1 @@ -51,9 +51,7 @@ class SrvPollingKnobs: def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = ( - pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl - ) + self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval @@ -73,14 +71,14 @@ class SrvPollingKnobs: else: patch_func = mock_get_hosts_and_min_ttl - pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore - pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore self.old_dns_resolver_response ) @@ -133,10 +131,7 @@ class TestSrvPolling(unittest.TestCase): def predicate(): if set(expected_nodelist) == set(self.get_nodelist(client)): - return ( - pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count - >= 1 - ) + return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 return False wait_until(predicate, "Node 
list equals expected nodelist", timeout=timeout) @@ -146,7 +141,7 @@ class TestSrvPolling(unittest.TestCase): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( - pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore 1, "resolver was never called", ) diff --git a/test/test_ssl.py b/test/test_ssl.py index 56dd23a8e..3b307df39 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -33,8 +33,8 @@ from urllib.parse import quote_plus from pymongo import MongoClient, ssl_support from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure +from pymongo.hello import HelloCompat from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context -from pymongo.synchronous.hello_compat import HelloCompat from pymongo.write_concern import WriteConcern _HAVE_PYOPENSSL = False diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 054910ca1..44e673822 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -29,8 +29,8 @@ from test.utils import ( wait_until, ) -from pymongo.synchronous import monitoring -from pymongo.synchronous.hello_compat import HelloCompat +from pymongo import monitoring +from pymongo.hello import HelloCompat class TestStreamingProtocol(IntegrationTest): diff --git a/test/test_topology.py b/test/test_topology.py index e6fd5a3c0..3725bab93 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -17,7 +17,7 @@ from __future__ import annotations import sys -from pymongo.synchronous.operations import _Op +from pymongo.operations import _Op sys.path[0:0] = [""] @@ -26,19 +26,19 @@ from test.pymongo_mocks import DummyMonitor from test.utils import MockPool, wait_until from bson.objectid import ObjectId +from pymongo import common from pymongo.errors import 
AutoReconnect, ConfigurationError, ConnectionFailure +from pymongo.hello import Hello, HelloCompat +from pymongo.read_preferences import ReadPreference, Secondary +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import any_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import common -from pymongo.synchronous.hello import Hello, HelloCompat from pymongo.synchronous.monitor import Monitor from pymongo.synchronous.pool import PoolOptions -from pymongo.synchronous.read_preferences import ReadPreference, Secondary from pymongo.synchronous.server import Server -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import any_server_selector, writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext, _filter_servers -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE class SetNameDiscoverySettings(TopologySettings): @@ -548,8 +548,8 @@ class TestMultiServerTopology(TopologyTest): HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a"], - "minWireVersion": 22, - "maxWireVersion": 24, + "minWireVersion": 26, + "maxWireVersion": 27, }, ) @@ -559,7 +559,7 @@ class TestMultiServerTopology(TopologyTest): # Error message should say which server failed and why. self.assertEqual( str(e), - "Server at a:27017 requires wire version 22, but this version " + "Server at a:27017 requires wire version 26, but this version " "of PyMongo only supports up to %d." 
% (common.MAX_SUPPORTED_WIRE_VERSION,), ) else: diff --git a/test/test_transactions.py b/test/test_transactions.py index 4279c942e..62525742d 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -43,13 +43,13 @@ from pymongo.errors import ( InvalidOperation, OperationFailure, ) +from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference from pymongo.synchronous import client_session from pymongo.synchronous.client_session import TransactionOptions from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor -from pymongo.synchronous.operations import IndexModel, InsertOne -from pymongo.synchronous.read_preferences import ReadPreference _TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG") diff --git a/test/test_typing.py b/test/test_typing.py index 552590c64..f423b70a3 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -75,9 +75,9 @@ from bson import CodecOptions, decode, decode_all, decode_file_iter, decode_iter from bson.raw_bson import RawBSONDocument from bson.son import SON from pymongo import ASCENDING, MongoClient +from pymongo.operations import DeleteOne, InsertOne, ReplaceOne +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.collection import Collection -from pymongo.synchronous.operations import DeleteOne, InsertOne, ReplaceOne -from pymongo.synchronous.read_preferences import ReadPreference TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails") diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index 09178e280..27f5fd2fb 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -28,7 +28,7 @@ from test import unittest from bson.binary import JAVA_LEGACY from pymongo import ReadPreference from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.synchronous.uri_parser import ( +from 
pymongo.uri_parser import ( parse_uri, parse_userinfo, split_hosts, diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index a5ec43649..3a8bf6275 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -26,9 +26,9 @@ sys.path[0:0] = [""] from test import clear_warning_registry, unittest -from pymongo.synchronous.common import INTERNAL_URI_OPTION_NAME_MAP, validate -from pymongo.synchronous.compression_support import _have_snappy -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate +from pymongo.compression_support import _have_snappy +from pymongo.uri_parser import parse_uri CONN_STRING_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test") diff --git a/test/unified_format.py b/test/unified_format.py index 0c200a15a..50190982c 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -68,6 +68,7 @@ from bson.objectid import ObjectId from bson.regex import RE_TYPE, Regex from gridfs import GridFSBucket, GridOut from pymongo import ASCENDING, CursorType, MongoClient, _csot +from pymongo.encryption_options import _HAVE_PYMONGOCRYPT from pymongo.errors import ( BulkWriteError, ConfigurationError, @@ -78,18 +79,7 @@ from pymongo.errors import ( OperationFailure, PyMongoError, ) -from pymongo.read_concern import ReadConcern -from pymongo.results import BulkWriteResult -from pymongo.server_api import ServerApi -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.change_stream import ChangeStream -from pymongo.synchronous.client_session import ClientSession, TransactionOptions, _TxnState -from pymongo.synchronous.collection import Collection -from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.database import Database -from pymongo.synchronous.encryption import ClientEncryption -from pymongo.synchronous.encryption_options import _HAVE_PYMONGOCRYPT -from 
pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( _SENSITIVE_COMMANDS, CommandFailedEvent, CommandListener, @@ -125,12 +115,22 @@ from pymongo.synchronous.monitoring import ( _ServerEvent, _ServerHeartbeatEvent, ) -from pymongo.synchronous.operations import SearchIndexModel -from pymongo.synchronous.read_preferences import ReadPreference -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import Selection, writable_server_selector -from pymongo.synchronous.topology_description import TopologyDescription -from pymongo.synchronous.typings import _Address +from pymongo.operations import SearchIndexModel +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.results import BulkWriteResult +from pymongo.server_api import ServerApi +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.change_stream import ChangeStream +from pymongo.synchronous.client_session import ClientSession, TransactionOptions, _TxnState +from pymongo.synchronous.collection import Collection +from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.database import Database +from pymongo.synchronous.encryption import ClientEncryption +from pymongo.topology_description import TopologyDescription +from pymongo.typings import _Address from pymongo.write_concern import WriteConcern JSON_OPTS = json_util.JSONOptions(tz_aware=False) diff --git a/test/utils.py b/test/utils.py index bd33270c1..98666e271 100644 --- a/test/utils.py +++ b/test/utils.py @@ -36,18 +36,13 @@ from typing import Any, List from bson import json_util from bson.objectid import ObjectId from bson.son import SON -from pymongo import AsyncMongoClient +from pymongo import AsyncMongoClient, monitoring, operations, 
read_preferences from pymongo.cursor_shared import CursorType from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.helpers_constants import _SENSITIVE_COMMANDS +from pymongo.hello import HelloCompat +from pymongo.helpers_shared import _SENSITIVE_COMMANDS from pymongo.lock import _create_lock -from pymongo.read_concern import ReadConcern -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import monitoring, operations, read_preferences -from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( ConnectionCheckedInEvent, ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, @@ -60,11 +55,15 @@ from pymongo.synchronous.monitoring import ( PoolCreatedEvent, PoolReadyEvent, ) -from pymongo.synchronous.operations import _Op +from pymongo.operations import _Op +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_selectors import any_server_selector, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration -from pymongo.synchronous.read_preferences import ReadPreference -from pymongo.synchronous.server_selectors import any_server_selector, writable_server_selector -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri from pymongo.write_concern import WriteConcern IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 7673e9bc2..cef5780d2 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -19,8 +19,6 @@ import datetime import 
os import sys -from pymongo.synchronous.operations import _Op - sys.path[0:0] = [""] from test import unittest @@ -28,11 +26,12 @@ from test.pymongo_mocks import DummyMonitor from test.utils import MockPool, parse_read_preference from bson import json_util +from pymongo.common import HEARTBEAT_FREQUENCY, clean_node from pymongo.errors import AutoReconnect, ConfigurationError -from pymongo.synchronous.common import HEARTBEAT_FREQUENCY, clean_node -from pymongo.synchronous.hello import Hello, HelloCompat -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import writable_server_selector +from pymongo.hello import Hello, HelloCompat +from pymongo.operations import _Op +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index e38d53b94..091533582 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -40,11 +40,11 @@ from bson.son import SON from gridfs import GridFSBucket from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult from pymongo.synchronous import client_session from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor -from pymongo.synchronous.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern diff --git a/tools/synchro.py b/tools/synchro.py index 2a0c4f431..1c555748f 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -33,6 +33,9 @@ replacements = { "AsyncCommandCursor": "CommandCursor", "AsyncRawBatchCursor": "RawBatchCursor", "AsyncRawBatchCommandCursor": 
"RawBatchCommandCursor", + "AsyncClientSession": "ClientSession", + "_AsyncBulk": "_Bulk", + "AsyncConnection": "Connection", "async_command": "command", "async_receive_message": "receive_message", "async_sendall": "sendall", @@ -89,6 +92,8 @@ docstring_replacements: dict[tuple[str, str], str] = { type_replacements = {"_Condition": "threading.Condition"} +import_replacements = {"test.synchronous": "test"} + _pymongo_base = "./pymongo/asynchronous/" _gridfs_base = "./gridfs/asynchronous/" _test_base = "./test/asynchronous/" @@ -125,26 +130,7 @@ sync_test_files = [ ] -docstring_translate_files = [ - _pymongo_dest_base + f - for f in [ - "aggregation.py", - "change_stream.py", - "collection.py", - "command_cursor.py", - "cursor.py", - "client_options.py", - "client_session.py", - "database.py", - "encryption.py", - "encryption_options.py", - "mongo_client.py", - "network.py", - "operations.py", - "pool.py", - "topology.py", - ] -] +docstring_translate_files = sync_files + sync_gridfs_files + sync_test_files def process_files(files: list[str]) -> None: @@ -152,23 +138,31 @@ def process_files(files: list[str]) -> None: if "__init__" not in file or "__init__" and "test" in file: with open(file, "r+") as f: lines = f.readlines() - lines = apply_is_sync(lines) + lines = apply_is_sync(lines, file) lines = translate_coroutine_types(lines) lines = translate_async_sleeps(lines) if file in docstring_translate_files: lines = translate_docstrings(lines) translate_locks(lines) translate_types(lines) + if file in sync_test_files: + translate_imports(lines) f.seek(0) f.writelines(lines) f.truncate() -def apply_is_sync(lines: list[str]) -> list[str]: - is_sync = next(iter([line for line in lines if line.startswith("_IS_SYNC = ")])) - index = lines.index(is_sync) - is_sync = is_sync.replace("False", "True") - lines[index] = is_sync +def apply_is_sync(lines: list[str], file: str) -> list[str]: + try: + is_sync = next(iter([line for line in lines if line.startswith("_IS_SYNC = ")])) + 
index = lines.index(is_sync) + is_sync = is_sync.replace("False", "True") + lines[index] = is_sync + except StopIteration as e: + print( + f"Missing _IS_SYNC at top of async file {file.replace('synchronous', 'asynchronous')}" + ) + raise e return lines @@ -212,6 +206,15 @@ def translate_types(lines: list[str]) -> list[str]: return lines +def translate_imports(lines: list[str]) -> list[str]: + for k, v in import_replacements.items(): + matches = [line for line in lines if k in line and "import" in line] + for line in matches: + index = lines.index(line) + lines[index] = line.replace(k, v) + return lines + + def translate_async_sleeps(lines: list[str]) -> list[str]: blocking_sleeps = [line for line in lines if "asyncio.sleep(0)" in line] lines = [line for line in lines if line not in blocking_sleeps] diff --git a/tools/synchro.sh b/tools/synchro.sh index fe48b663b..f5e7ab68c 100644 --- a/tools/synchro.sh +++ b/tools/synchro.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash -eu python ./tools/synchro.py python -m ruff check pymongo/synchronous/ gridfs/synchronous/ test/synchronous --fix --silent diff --git a/tox.ini b/tox.ini index 331c73ce1..a154bf424 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,8 @@ requires = envlist = # Test using the system Python. test, + # Test async tests using the system Python. + test-async, # Test using the run-tests Evergreen script. test-eg, # Set up encryption files and services. @@ -34,6 +36,7 @@ envlist = labels = # Use labels and -m instead of -e so that tox -m