Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-07-02 09:05:16 -05:00
commit 66935c06cb
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
198 changed files with 9321 additions and 20036 deletions

View File

@ -58,14 +58,12 @@ functions:
export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration"
export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin"
export UPLOAD_BUCKET="${project}"
cat <<EOT > expansion.yml
CURRENT_VERSION: "$CURRENT_VERSION"
DRIVERS_TOOLS: "$DRIVERS_TOOLS"
MONGO_ORCHESTRATION_HOME: "$MONGO_ORCHESTRATION_HOME"
MONGODB_BINARIES: "$MONGODB_BINARIES"
UPLOAD_BUCKET: "$UPLOAD_BUCKET"
PROJECT_DIRECTORY: "$PROJECT_DIRECTORY"
PREPARE_SHELL: |
set -o errexit
@ -73,7 +71,6 @@ functions:
export DRIVERS_TOOLS="$DRIVERS_TOOLS"
export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
export MONGODB_BINARIES="$MONGODB_BINARIES"
export UPLOAD_BUCKET="$UPLOAD_BUCKET"
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
export TMPDIR="$MONGO_ORCHESTRATION_HOME/db"
@ -103,30 +100,35 @@ functions:
echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config
"upload coverage" :
- command: ec2.assume_role
params:
role_arn: ${assume_role_arn}
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: src/.coverage
optional: true
# Upload the coverage report for all tasks in a single build to the same directory.
remote_file: ${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}
bucket: mciuploads
remote_file: coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}
bucket: ${bucket_name}
permissions: public-read
content_type: text/html
display_name: "Raw Coverage Report"
"download and merge coverage" :
- command: ec2.assume_role
params:
role_arn: ${assume_role_arn}
- command: shell.exec
params:
silent: true
working_dir: "src"
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
script: |
export AWS_ACCESS_KEY_ID=${aws_key}
export AWS_SECRET_ACCESS_KEY=${aws_secret}
# Download all the task coverage files.
aws s3 cp --recursive s3://mciuploads/${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/coverage/ coverage/
aws s3 cp --recursive s3://${bucket_name}/coverage/${revision}/${version_id}/coverage/ coverage/
- command: shell.exec
params:
working_dir: "src"
@ -138,24 +140,27 @@ functions:
params:
silent: true
working_dir: "src"
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
script: |
export AWS_ACCESS_KEY_ID=${aws_key}
export AWS_SECRET_ACCESS_KEY=${aws_secret}
aws s3 cp htmlcov/ s3://mciuploads/${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1
aws s3 cp htmlcov/ s3://${bucket_name}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1
# Attach the index.html with s3.put so it shows up in the Evergreen UI.
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: src/htmlcov/index.html
remote_file: ${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/htmlcov/index.html
bucket: mciuploads
remote_file: coverage/${revision}/${version_id}/htmlcov/index.html
bucket: ${bucket_name}
permissions: public-read
content_type: text/html
display_name: "Coverage Report HTML"
"upload mo artifacts":
- command: ec2.assume_role
params:
role_arn: ${assume_role_arn}
- command: shell.exec
params:
script: |
@ -174,37 +179,43 @@ functions:
- "./**.mdmp" # Windows: minidumps
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: mongo-coredumps.tgz
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz
bucket: mciuploads
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|application/gzip}
display_name: Core Dumps - Execution
optional: true
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: mongodb-logs.tar.gz
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz
bucket: mciuploads
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|application/x-gzip}
display_name: "mongodb-logs.tar.gz"
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: drivers-tools/.evergreen/orchestration/server.log
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log
bucket: mciuploads
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|text/plain}
display_name: "orchestration.log"
"upload working dir":
- command: ec2.assume_role
params:
role_arn: ${assume_role_arn}
- command: archive.targz_pack
params:
target: "working-dir.tar.gz"
@ -213,11 +224,12 @@ functions:
- "./**"
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: working-dir.tar.gz
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz
bucket: mciuploads
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|application/x-gzip}
display_name: "working-dir.tar.gz"
@ -232,11 +244,12 @@ functions:
- "*.lock"
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: drivers-dir.tar.gz
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz
bucket: mciuploads
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|application/x-gzip}
display_name: "drivers-dir.tar.gz"
@ -785,6 +798,9 @@ functions:
VERSION=${VERSION} ENSURE_UNIVERSAL2=${ENSURE_UNIVERSAL2} .evergreen/release.sh
"upload release":
- command: ec2.assume_role
params:
role_arn: ${assume_role_arn}
- command: archive.targz_pack
params:
target: "release-files.tgz"
@ -793,25 +809,27 @@ functions:
- "*"
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: release-files.tgz
remote_file: ${UPLOAD_BUCKET}/release/${revision}/${task_id}-${execution}-release-files.tar.gz
bucket: mciuploads
remote_file: release/${revision}/${task_id}-${execution}-release-files.tar.gz
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|application/gzip}
display_name: Release files
"download and merge releases":
- command: ec2.assume_role
params:
role_arn: ${assume_role_arn}
- command: shell.exec
params:
silent: true
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
script: |
export AWS_ACCESS_KEY_ID=${aws_key}
export AWS_SECRET_ACCESS_KEY=${aws_secret}
# Download all the task coverage files.
aws s3 cp --recursive s3://mciuploads/${UPLOAD_BUCKET}/release/${revision}/ release/
aws s3 cp --recursive s3://${bucket_name}/release/${revision}/ release/
- command: shell.exec
params:
shell: "bash"
@ -845,11 +863,12 @@ functions:
- "*"
- command: s3.put
params:
aws_key: ${aws_key}
aws_secret: ${aws_secret}
aws_key: ${AWS_ACCESS_KEY_ID}
aws_secret: ${AWS_SECRET_ACCESS_KEY}
aws_session_token: ${AWS_SESSION_TOKEN}
local_file: release-files-all.tgz
remote_file: ${UPLOAD_BUCKET}/release-all/${revision}/${task_id}-${execution}-release-files-all.tar.gz
bucket: mciuploads
remote_file: release-all/${revision}/${task_id}-${execution}-release-files-all.tar.gz
bucket: ${bucket_name}
permissions: public-read
content_type: ${content_type|application/gzip}
display_name: Release files all
@ -2108,7 +2127,7 @@ tasks:
script: |
${PREPARE_SHELL}
export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3
export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian10/master/latest/libmongocrypt.tar.gz
export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/${bucket_name}/libmongocrypt/debian10/master/latest/libmongocrypt.tar.gz
SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/tox.sh -m test-eg
- name: testazurekms-task

View File

@ -258,6 +258,10 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
# Use --capture=tee-sys so pytest prints test output inline:
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
if [ -z "$TEST_ARGS" ]; then # TODO: remove this in PYTHON-4528
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/synchronous/ $TEST_ARGS
fi
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS
else
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
fi

View File

@ -26,9 +26,6 @@ jobs:
# required for all workflows
security-events: write
# required to fetch internal or private CodeQL packs
packages: read
strategy:
fail-fast: false
matrix:

View File

@ -10,6 +10,10 @@ on:
workflow_dispatch:
pull_request:
workflow_call:
inputs:
ref:
required: true
type: string
concurrency:
group: dist-${{ github.ref }}
@ -44,6 +48,7 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ inputs.ref }}
- uses: actions/setup-python@v5
with:
@ -99,6 +104,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ inputs.ref }}
- uses: actions/setup-python@v5
with:

View File

@ -19,7 +19,7 @@ env:
PRODUCT_NAME: PyMongo
# Changes per branch
SILK_ASSET_GROUP: mongodb-python-driver
EVERGREEN_PROJECT: mongodb-python-driver
EVERGREEN_PROJECT: mongo-python-driver
defaults:
run:
@ -32,6 +32,8 @@ jobs:
permissions:
id-token: write
contents: write
outputs:
version: ${{ steps.pre-publish.outputs.version }}
steps:
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v2
with:
@ -44,6 +46,7 @@ jobs:
aws_secret_id: ${{ secrets.AWS_SECRET_ID }}
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
- uses: mongodb-labs/drivers-github-tools/python/pre-publish@v2
id: pre-publish
with:
version: ${{ inputs.version }}
dry_run: ${{ inputs.dry_run }}
@ -51,12 +54,16 @@ jobs:
build-dist:
needs: [pre-publish]
uses: ./.github/workflows/dist.yml
with:
ref: ${{ needs.pre-publish.outputs.version }}
static-scan:
needs: [pre-publish]
uses: ./.github/workflows/codeql.yml
permissions:
security-events: write
with:
ref: ${{ github.ref }}
ref: ${{ needs.pre-publish.outputs.version }}
publish:
needs: [build-dist, static-scan]

View File

@ -71,6 +71,9 @@ jobs:
- name: Run tests
run: |
tox -m test
- name: Run async tests
run: |
tox -m test-async
doctest:
runs-on: ubuntu-latest
@ -203,3 +206,5 @@ jobs:
which python
pip install -e ".[test]"
pytest -v
pytest -v test/synchronous/
pytest -v test/asynchronous/

View File

@ -14,7 +14,7 @@ a native Python driver for MongoDB. The `gridfs` package is a
[gridfs](https://github.com/mongodb/specifications/blob/master/source/gridfs/gridfs-spec.rst/)
implementation on top of `pymongo`.
PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, and 7.0.
PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, 7.0, and 8.0.
## Support / Feedback

View File

@ -6,7 +6,11 @@ Changes in Version 4.9.0
PyMongo 4.9 brings a number of improvements including:
- Added support for MongoDB 8.0.
- A new asynchronous API with full asyncio support.
- Add support for :attr:`~pymongo.encryption.Algorithm.RANGE` and deprecate
:attr:`~pymongo.encryption.Algorithm.RANGEPREVIEW`.
- pymongocrypt>=1.10 is now required for :ref:`In-Use Encryption` support.
Issues Resolved
...............
@ -26,12 +30,20 @@ PyMongo 4.8 brings a number of improvements including:
- The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time.
- The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8).
- Secure Software Development Life Cycle automation for release process.
GitHub Releases now include a Software Bill of Materials, and signature
files corresponding to the distribution files released on PyPI.
- Fixed a bug in change streams where both ``startAtOperationTime`` and ``resumeToken``
could be added to a retry attempt, which caused the retry to fail.
- Fallback to stdlib ``ssl`` module when ``pyopenssl`` import fails with AttributeError.
- Improved performance of MongoClient operations, especially when many operations are being run concurrently.
Unavoidable breaking changes
............................
- Since we are now using ``hatch`` as our build backend, we no longer have a ``setup.py`` file
and require installation using ``pip``.
- Since we are now using ``hatch`` as our build backend, we no longer have a usable ``setup.py`` file
and require installation using ``pip``. Attempts to invoke the ``setup.py`` file will raise an exception.
Additionally, ``pip`` >= 21.3 is now required for editable installs.
Issues Resolved
...............

View File

@ -42,13 +42,12 @@ from gridfs.grid_file_shared import (
_clear_entity_type_registry,
)
from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.common import validate_string
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.helpers import _check_write_command_response, anext
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.asynchronous.helpers import anext
from pymongo.common import validate_string
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
@ -57,11 +56,13 @@ from pymongo.errors import (
InvalidOperation,
OperationFailure,
)
from pymongo.helpers_shared import _check_write_command_response
from pymongo.read_preferences import ReadPreference, _ServerMode
_IS_SYNC = False
def _disallow_transactions(session: Optional[ClientSession]) -> None:
def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
if session and session.in_transaction:
raise InvalidOperation("GridFS does not support multi-document transactions")
@ -155,7 +156,7 @@ class AsyncGridFS:
await grid_file.write(data)
return await grid_file._id
async def get(self, file_id: Any, session: Optional[ClientSession] = None) -> AsyncGridOut:
async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut:
"""Get a file from GridFS by ``"_id"``.
Returns an instance of :class:`~gridfs.grid_file.GridOut`,
@ -163,7 +164,7 @@ class AsyncGridFS:
:param file_id: ``"_id"`` of the file to get
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -178,7 +179,7 @@ class AsyncGridFS:
self,
filename: Optional[str] = None,
version: Optional[int] = -1,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> AsyncGridOut:
"""Get a file from GridFS by ``"filename"`` or metadata fields.
@ -205,7 +206,7 @@ class AsyncGridFS:
:param version: version of the file to get (defaults
to -1, the most recent version uploaded)
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
:param kwargs: find files by custom metadata.
.. versionchanged:: 3.6
@ -234,7 +235,10 @@ class AsyncGridFS:
raise NoFile("no version %d for filename %r" % (version, filename)) from None
async def get_last_version(
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
self,
filename: Optional[str] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> AsyncGridOut:
"""Get the most recent version of a file in GridFS by ``"filename"``
or metadata fields.
@ -244,7 +248,7 @@ class AsyncGridFS:
:param filename: ``"filename"`` of the file to get, or `None`
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
:param kwargs: find files by custom metadata.
.. versionchanged:: 3.6
@ -253,7 +257,7 @@ class AsyncGridFS:
return await self.get_version(filename=filename, session=session, **kwargs)
# TODO add optional safe mode for chunk removal?
async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None:
"""Delete a file from GridFS by ``"_id"``.
Deletes all data belonging to the file with ``"_id"``:
@ -269,7 +273,7 @@ class AsyncGridFS:
:param file_id: ``"_id"`` of the file to delete
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -281,12 +285,12 @@ class AsyncGridFS:
await self._files.delete_one({"_id": file_id}, session=session)
await self._chunks.delete_many({"files_id": file_id}, session=session)
async def list(self, session: Optional[ClientSession] = None) -> list[str]:
async def list(self, session: Optional[AsyncClientSession] = None) -> list[str]:
"""List the names of all files stored in this instance of
:class:`GridFS`.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -306,7 +310,7 @@ class AsyncGridFS:
async def find_one(
self,
filter: Optional[Any] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
*args: Any,
**kwargs: Any,
) -> Optional[AsyncGridOut]:
@ -327,7 +331,7 @@ class AsyncGridFS:
:param args: any additional positional arguments are
the same as the arguments to :meth:`find`.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
:param kwargs: any additional keyword arguments
are the same as the arguments to :meth:`find`.
@ -370,7 +374,7 @@ class AsyncGridFS:
:meth:`~pymongo.collection.Collection.find`
in :class:`~pymongo.collection.Collection`.
If a :class:`~pymongo.client_session.ClientSession` is passed to
If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
are associated with that session.
@ -406,7 +410,7 @@ class AsyncGridFS:
async def exists(
self,
document_or_id: Optional[Any] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> bool:
"""Check if a file exists in this instance of :class:`GridFS`.
@ -438,7 +442,7 @@ class AsyncGridFS:
:param document_or_id: query document, or _id of the
document to check for
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
:param kwargs: keyword arguments are used as a
query document, if they're present.
@ -525,7 +529,7 @@ class AsyncGridFSBucket:
filename: str,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> AsyncGridIn:
"""Opens a Stream that the application can write the contents of the
file to.
@ -556,7 +560,7 @@ class AsyncGridFSBucket:
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -580,7 +584,7 @@ class AsyncGridFSBucket:
filename: str,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> AsyncGridIn:
"""Opens a Stream that the application can write the contents of the
file to.
@ -615,7 +619,7 @@ class AsyncGridFSBucket:
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -641,7 +645,7 @@ class AsyncGridFSBucket:
source: Any,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> ObjectId:
"""Uploads a user file to a GridFS bucket.
@ -672,7 +676,7 @@ class AsyncGridFSBucket:
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -692,7 +696,7 @@ class AsyncGridFSBucket:
source: Any,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Uploads a user file to a GridFS bucket with a custom file id.
@ -724,7 +728,7 @@ class AsyncGridFSBucket:
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -735,7 +739,7 @@ class AsyncGridFSBucket:
await gin.write(source)
async def open_download_stream(
self, file_id: Any, session: Optional[ClientSession] = None
self, file_id: Any, session: Optional[AsyncClientSession] = None
) -> AsyncGridOut:
"""Opens a Stream from which the application can read the contents of
the stored file specified by file_id.
@ -755,7 +759,7 @@ class AsyncGridFSBucket:
:param file_id: The _id of the file to be downloaded.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -768,7 +772,7 @@ class AsyncGridFSBucket:
@_csot.apply
async def download_to_stream(
self, file_id: Any, destination: Any, session: Optional[ClientSession] = None
self, file_id: Any, destination: Any, session: Optional[AsyncClientSession] = None
) -> None:
"""Downloads the contents of the stored file specified by file_id and
writes the contents to `destination`.
@ -790,7 +794,7 @@ class AsyncGridFSBucket:
:param file_id: The _id of the file to be downloaded.
:param destination: a file-like object implementing :meth:`write`.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -803,7 +807,7 @@ class AsyncGridFSBucket:
destination.write(chunk)
@_csot.apply
async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None:
"""Given an file_id, delete this stored file's files collection document
and associated chunks from a GridFS bucket.
@ -819,7 +823,7 @@ class AsyncGridFSBucket:
:param file_id: The _id of the file to be deleted.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -859,7 +863,7 @@ class AsyncGridFSBucket:
:meth:`~pymongo.collection.Collection.find`
in :class:`~pymongo.collection.Collection`.
If a :class:`~pymongo.client_session.ClientSession` is passed to
If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
are associated with that session.
@ -878,7 +882,7 @@ class AsyncGridFSBucket:
return AsyncGridOutCursor(self._collection, *args, **kwargs)
async def open_download_stream_by_name(
self, filename: str, revision: int = -1, session: Optional[ClientSession] = None
self, filename: str, revision: int = -1, session: Optional[AsyncClientSession] = None
) -> AsyncGridOut:
"""Opens a Stream from which the application can read the contents of
`filename` and optional `revision`.
@ -902,7 +906,7 @@ class AsyncGridFSBucket:
filename and different uploadDate) of the file to retrieve.
Defaults to -1 (the most recent revision).
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
:Note: Revision numbers are defined as follows:
@ -937,7 +941,7 @@ class AsyncGridFSBucket:
filename: str,
destination: Any,
revision: int = -1,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Write the contents of `filename` (with optional `revision`) to
`destination`.
@ -961,7 +965,7 @@ class AsyncGridFSBucket:
filename and different uploadDate) of the file to retrieve.
Defaults to -1 (the most recent revision).
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
:Note: Revision numbers are defined as follows:
@ -985,7 +989,7 @@ class AsyncGridFSBucket:
destination.write(chunk)
async def rename(
self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None
self, file_id: Any, new_filename: str, session: Optional[AsyncClientSession] = None
) -> None:
"""Renames the stored file with the specified file_id.
@ -1002,7 +1006,7 @@ class AsyncGridFSBucket:
:param file_id: The _id of the file to be renamed.
:param new_filename: The new name of the file.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:class:`~pymongo.client_session.AsyncClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -1024,7 +1028,7 @@ class AsyncGridIn:
def __init__(
self,
root_collection: AsyncCollection,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> None:
"""Write a file to GridFS
@ -1059,7 +1063,7 @@ class AsyncGridIn:
:param root_collection: root collection to write to
:param session: a
:class:`~pymongo.client_session.ClientSession` to use for all
:class:`~pymongo.client_session.AsyncClientSession` to use for all
commands
:param kwargs: Any: file level options (see above)
@ -1402,7 +1406,7 @@ class AsyncGridOut(io.IOBase):
root_collection: AsyncCollection,
file_id: Optional[int] = None,
file_document: Optional[Any] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Read a file from GridFS
@ -1420,7 +1424,7 @@ class AsyncGridOut(io.IOBase):
:param file_document: file document from
`root_collection.files`
:param session: a
:class:`~pymongo.client_session.ClientSession` to use for all
:class:`~pymongo.client_session.AsyncClientSession` to use for all
commands
.. versionchanged:: 3.8
@ -1734,7 +1738,7 @@ class _AsyncGridOutChunkIterator:
self,
grid_out: AsyncGridOut,
chunks: AsyncCollection,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
next_chunk: Any,
) -> None:
self._id = grid_out._id
@ -1824,7 +1828,9 @@ class _AsyncGridOutChunkIterator:
class AsyncGridOutIterator:
def __init__(self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: ClientSession):
def __init__(
self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
):
self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)
def __aiter__(self) -> AsyncGridOutIterator:
@ -1851,7 +1857,7 @@ class AsyncGridOutCursor(AsyncCursor):
no_cursor_timeout: bool = False,
sort: Optional[Any] = None,
batch_size: int = 0,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Create a new cursor, similar to the normal
:class:`~pymongo.cursor.Cursor`.
@ -1894,6 +1900,6 @@ class AsyncGridOutCursor(AsyncCursor):
def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("Method does not exist for GridOutCursor")
def _clone_base(self, session: Optional[ClientSession]) -> AsyncGridOutCursor:
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncGridOutCursor:
"""Creates an empty GridOutCursor for information to be copied into."""
return AsyncGridOutCursor(self._root_collection, session=session)

View File

@ -5,7 +5,7 @@ import warnings
from typing import Any, Optional
from pymongo import ASCENDING
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import InvalidOperation
_SEEK_SET = os.SEEK_SET

View File

@ -42,6 +42,7 @@ from gridfs.grid_file_shared import (
_grid_out_property,
)
from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot
from pymongo.common import validate_string
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
@ -50,13 +51,13 @@ from pymongo.errors import (
InvalidOperation,
OperationFailure,
)
from pymongo.helpers_shared import _check_write_command_response
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.common import validate_string
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.helpers import _check_write_command_response, next
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.synchronous.helpers import next
_IS_SYNC = True
@ -234,7 +235,10 @@ class GridFS:
raise NoFile("no version %d for filename %r" % (version, filename)) from None
def get_last_version(
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
self,
filename: Optional[str] = None,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> GridOut:
"""Get the most recent version of a file in GridFS by ``"filename"``
or metadata fields.
@ -497,7 +501,7 @@ class GridFSBucket:
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(db, Database):
raise TypeError("database must be an instance of AsyncDatabase")
raise TypeError("database must be an instance of Database")
db = _clear_entity_type_registry(db)
@ -1028,7 +1032,7 @@ class GridIn:
provided by :class:`~gridfs.GridFS`.
Raises :class:`TypeError` if `root_collection` is not an
instance of :class:`~pymongo.collection.AsyncCollection`.
instance of :class:`~pymongo.collection.Collection`.
Any of the file level options specified in the `GridFS Spec
<http://dochub.mongodb.org/core/gridfsspec>`_ may be passed as
@ -1069,10 +1073,10 @@ class GridIn:
.. versionchanged:: 3.0
`root_collection` must use an acknowledged
:attr:`~pymongo.collection.AsyncCollection.write_concern`
:attr:`~pymongo.collection.Collection.write_concern`
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an instance of AsyncCollection")
raise TypeError("root_collection must be an instance of Collection")
if not root_collection.write_concern.acknowledged:
raise ConfigurationError("root_collection must use acknowledged write_concern")
@ -1401,7 +1405,7 @@ class GridOut(io.IOBase):
Either `file_id` or `file_document` must be specified,
`file_document` will be given priority if present. Raises
:class:`TypeError` if `root_collection` is not an instance of
:class:`~pymongo.collection.AsyncCollection`.
:class:`~pymongo.collection.Collection`.
:param root_collection: root collection to read from
:param file_id: value of ``"_id"`` for the file to read
@ -1424,7 +1428,7 @@ class GridOut(io.IOBase):
from the server. Metadata is fetched when first needed.
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an instance of AsyncCollection")
raise TypeError("root_collection must be an instance of Collection")
_disallow_transactions(session)
root_collection = _clear_entity_type_registry(root_collection)
@ -1482,7 +1486,7 @@ class GridOut(io.IOBase):
self.open() # type: ignore[unused-coroutine]
elif not self._file:
raise InvalidOperation(
"You must call AsyncGridOut.open() before accessing the %s property" % name
"You must call GridOut.open() before accessing the %s property" % name
)
if name in self._file:
return self._file[name]
@ -1677,13 +1681,13 @@ class GridOut(io.IOBase):
return False
def __enter__(self) -> GridOut:
"""Makes it possible to use :class:`AsyncGridOut` files
"""Makes it possible to use :class:`GridOut` files
with the async context manager protocol.
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
"""Makes it possible to use :class:`AsyncGridOut` files
"""Makes it possible to use :class:`GridOut` files
with the async context manager protocol.
"""
self.close()

View File

@ -89,11 +89,9 @@ TEXT = "text"
from pymongo import _csot
from pymongo._version import __version__, get_version_string, version_tuple
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
from pymongo.cursor import CursorType
from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.operations import (
from pymongo.operations import (
DeleteMany,
DeleteOne,
IndexModel,
@ -102,7 +100,9 @@ from pymongo.synchronous.operations import (
UpdateMany,
UpdateOne,
)
from pymongo.synchronous.read_preferences import ReadPreference
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
version = __version__

View File

@ -18,20 +18,20 @@ from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo.asynchronous import common
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.asynchronous.typings import _DocumentType, _Pipeline
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _DocumentType, _Pipeline
_IS_SYNC = False
@ -53,7 +53,7 @@ class _AggregationCommand:
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
comment: Any = None,
) -> None:
if "explain" in options:
@ -121,7 +121,7 @@ class _AggregationCommand:
raise NotImplementedError
def get_read_preference(
self, session: Optional[ClientSession]
self, session: Optional[AsyncClientSession]
) -> Union[_AggWritePref, _ServerMode]:
if self._write_preference:
return self._write_preference
@ -132,9 +132,9 @@ class _AggregationCommand:
async def get_cursor(
self,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
# Serialize command.

View File

@ -18,17 +18,13 @@ from __future__ import annotations
import functools
import hashlib
import hmac
import os
import socket
import typing
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Mapping,
MutableMapping,
Optional,
@ -41,17 +37,19 @@ from pymongo.asynchronous.auth_aws import _authenticate_aws
from pymongo.asynchronous.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
from pymongo.auth_shared import (
MongoCredential,
_authenticate_scram_start,
_parse_scram_response,
_xor,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
if TYPE_CHECKING:
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.hello import Hello
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
@ -69,213 +67,9 @@ except ImportError:
_IS_SYNC = False
MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DEFAULT",
]
)
"""The authentication mechanisms supported by PyMongo."""
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
def __init__(self) -> None:
self.data = None
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self) -> int:
return self._hash_val
MongoCredential = namedtuple(
"MongoCredential",
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
)
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple(
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
)
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
properties = extra.get("authmechanismproperties", {})
service_name = properties.get("SERVICE_NAME", "mongodb")
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
service_realm = properties.get("SERVICE_REALM")
props = GSSAPIProperties(
service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm,
)
# Source is always $external.
return MongoCredential(mech, "$external", user, passwd, props, None)
elif mech == "MONGODB-X509":
if passwd is not None:
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for MONGODB-X509")
# Source is always $external, user can be None.
return MongoCredential(mech, "$external", user, None, None, None)
elif mech == "MONGODB-AWS":
if user is not None and passwd is None:
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
if source is not None and source != "$external":
raise ConfigurationError(
"authentication source must be $external or None for MONGODB-AWS"
)
properties = extra.get("authmechanismproperties", {})
aws_session_token = properties.get("AWS_SESSION_TOKEN")
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
callback = properties.get("OIDC_CALLBACK")
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
environ = properties.get("ENVIRONMENT")
token_resource = properties.get("TOKEN_RESOURCE", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
if passwd is not None:
msg = "password is not supported by MONGODB-OIDC"
raise ConfigurationError(msg)
if callback or human_callback:
if environ is not None:
raise ConfigurationError(msg)
if callback and human_callback:
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
raise ConfigurationError(msg)
elif environ is not None:
if environ == "test":
if user is not None:
msg = "test environment for MONGODB-OIDC does not support username"
raise ConfigurationError(msg)
callback = _OIDCTestCallback()
elif environ == "azure":
passwd = None
if not token_resource:
raise ConfigurationError(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:
raise ConfigurationError(msg)
oidc_props = _OIDCProperties(
callback=callback,
human_callback=human_callback,
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
"""Split a scram response into key, value pairs."""
return dict(
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
for item in response.split(b",")
)
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = {
"saslStart": 1,
"mechanism": mechanism,
"payload": Binary(b"n,," + first_bare),
"autoAuthorize": 1,
"options": {"skipEmptyExchange": True},
}
return nonce, first_bare, cmd
async def _authenticate_scram(
credentials: MongoCredential, conn: Connection, mechanism: str
credentials: MongoCredential, conn: AsyncConnection, mechanism: str
) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
@ -398,7 +192,7 @@ def _canonicalize_hostname(hostname: str) -> str:
return name[0].lower()
async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
@ -509,7 +303,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -
raise OperationFailure(str(exc)) from None
async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
async def _authenticate_plain(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
@ -524,7 +318,7 @@ async def _authenticate_plain(credentials: MongoCredential, conn: Connection) ->
await conn.command(source, cmd)
async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
async def _authenticate_x509(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using MONGODB-X509."""
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
@ -535,7 +329,7 @@ async def _authenticate_x509(credentials: MongoCredential, conn: Connection) ->
await conn.command("$external", cmd)
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
@ -550,7 +344,7 @@ async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection)
await conn.command(source, query)
async def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
async def _authenticate_default(credentials: MongoCredential, conn: AsyncConnection) -> None:
if conn.max_wire_version >= 7:
if conn.negotiated_mechs:
mechs = conn.negotiated_mechs
@ -652,7 +446,7 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
async def authenticate(
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool = False
) -> None:
"""Authenticate connection."""
mechanism = credentials.mechanism

View File

@ -23,13 +23,13 @@ from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING:
from bson.typings import _ReadableBuffer
from pymongo.asynchronous.auth import MongoCredential
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.auth_shared import MongoCredential
_IS_SYNC = False
async def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
async def _authenticate_aws(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using MONGODB-AWS."""
try:
import pymongo_auth_aws # type:ignore[import]

View File

@ -15,79 +15,35 @@
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import abc
import os
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
from urllib.parse import quote
import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.auth_oidc_shared import (
CALLBACK_VERSION,
HUMAN_CALLBACK_TIMEOUT_SECONDS,
MACHINE_CALLBACK_TIMEOUT_SECONDS,
TIME_BETWEEN_CALLS_SECONDS,
OIDCCallback,
OIDCCallbackContext,
OIDCCallbackResult,
OIDCIdPInfo,
_OIDCProperties,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.asynchronous.auth import MongoCredential
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.auth_shared import MongoCredential
_IS_SYNC = False
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@dataclass
class OIDCCallbackContext:
timeout_seconds: float
username: str
version: int
refresh_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
@dataclass
class OIDCCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)
class OIDCCallback(abc.ABC):
"""A base class for defining OIDC callbacks."""
@abc.abstractmethod
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
"""Convert the given BSON value into our own type."""
@dataclass
class _OIDCProperties:
callback: Optional[OIDCCallback] = field(default=None)
human_callback: Optional[OIDCCallback] = field(default=None)
environment: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)
token_resource: Optional[str] = field(default=None)
username: str = ""
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
def _get_authenticator(
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
@ -117,48 +73,6 @@ def _get_authenticator(
return credentials.cache.data
class _OIDCTestCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("OIDC_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAWSCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAzureCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
return OIDCCallbackResult(
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])
@dataclass
class _OIDCAuthenticator:
username: str
@ -170,7 +84,7 @@ class _OIDCAuthenticator:
lock: threading.Lock = field(default_factory=threading.Lock)
last_call_time: float = field(default=0)
async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
"""Handle a reauthenticate from the server."""
# Invalidate the token for the connection.
self._invalidate(conn)
@ -179,7 +93,7 @@ class _OIDCAuthenticator:
return await self._authenticate_machine(conn)
return await self._authenticate_human(conn)
async def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
async def authenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
"""Handle an initial authenticate request."""
# First handle speculative auth.
# If it succeeded, we are done.
@ -203,7 +117,7 @@ class _OIDCAuthenticator:
return None
return self._get_start_command({"jwt": self.access_token})
async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
async def _authenticate_machine(self, conn: AsyncConnection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
@ -217,7 +131,7 @@ class _OIDCAuthenticator:
raise
return await self._sasl_start_jwt(conn)
async def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
@ -307,7 +221,7 @@ class _OIDCAuthenticator:
return self.access_token
async def _run_command(
self, conn: Connection, cmd: MutableMapping[str, Any]
self, conn: AsyncConnection, cmd: MutableMapping[str, Any]
) -> Mapping[str, Any]:
try:
return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
@ -321,7 +235,7 @@ class _OIDCAuthenticator:
return False
return err.code == _AUTHENTICATION_FAILURE_CODE
def _invalidate(self, conn: Connection) -> None:
def _invalidate(self, conn: AsyncConnection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.
token_gen_id = conn.oidc_token_gen_id or 0
@ -330,7 +244,7 @@ class _OIDCAuthenticator:
self.access_token = None
async def _sasl_continue_jwt(
self, conn: Connection, start_resp: Mapping[str, Any]
self, conn: AsyncConnection, start_resp: Mapping[str, Any]
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
@ -342,7 +256,7 @@ class _OIDCAuthenticator:
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
return await self._run_command(conn, cmd)
async def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_start_command({"jwt": access_token})
@ -370,7 +284,7 @@ class _OIDCAuthenticator:
async def _authenticate_oidc(
credentials: MongoCredential, conn: Connection, reauthenticate: bool
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, conn.address)

View File

@ -19,6 +19,8 @@
from __future__ import annotations
import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
@ -26,7 +28,6 @@ from typing import (
Any,
Iterator,
Mapping,
NoReturn,
Optional,
Type,
Union,
@ -34,142 +35,52 @@ from typing import (
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot
from pymongo.asynchronous import common
from pymongo.asynchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.asynchronous.common import (
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.bulk_shared import (
_COMMANDS,
_DELETE_ALL,
_merge_command,
_raise_bulk_write_error,
_Run,
)
from pymongo.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo.asynchronous.helpers import _get_wce_doc
from pymongo.asynchronous.message import (
from pymongo.errors import (
ConfigurationError,
InvalidOperation,
NotPrimaryError,
OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_convert_exception,
_convert_write_result,
_EncryptedBulkWriteContext,
_randint,
)
from pymongo.asynchronous.read_preferences import ReadPreference
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure,
)
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.typings import _DocumentOut, _DocumentType, _Pipeline
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
_IS_SYNC = False
_DELETE_ALL: int = 0
_DELETE_ONE: int = 1
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
_BAD_VALUE: int = 2
_UNKNOWN_ERROR: int = 8
_WRITE_CONCERN_ERROR: int = 64
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
class _Run:
"""Represents a batch of write operations."""
def __init__(self, op_type: int) -> None:
"""Initialize a new Run object."""
self.op_type: int = op_type
self.index_map: list[int] = []
self.ops: list[Any] = []
self.idx_offset: int = 0
def index(self, idx: int) -> int:
"""Get the original index of an operation in this run.
:param idx: The Run index that maps to the original index.
"""
return self.index_map[idx]
def add(self, original_index: int, operation: Any) -> None:
"""Add an operation to this Run instance.
:param original_index: The original index of this operation
within a larger bulk operation.
:param operation: The operation document.
"""
self.index_map.append(original_index)
self.ops.append(operation)
def _merge_command(
run: _Run,
full_result: MutableMapping[str, Any],
offset: int,
result: Mapping[str, Any],
) -> None:
"""Merge a write command result into the full bulk result."""
affected = result.get("n", 0)
if run.op_type == _INSERT:
full_result["nInserted"] += affected
elif run.op_type == _DELETE:
full_result["nRemoved"] += affected
elif run.op_type == _UPDATE:
upserted = result.get("upserted")
if upserted:
n_upserted = len(upserted)
for doc in upserted:
doc["index"] = run.index(doc["index"] + offset)
full_result["upserted"].extend(upserted)
full_result["nUpserted"] += n_upserted
full_result["nMatched"] += affected - n_upserted
else:
full_result["nMatched"] += affected
full_result["nModified"] += result["nModified"]
write_errors = result.get("writeErrors")
if write_errors:
for doc in write_errors:
# Leave the server response intact for APM.
replacement = doc.copy()
idx = doc["index"] + offset
replacement["index"] = run.index(idx)
# Add the failed operation to the error document.
replacement["op"] = run.ops[idx]
full_result["writeErrors"].append(replacement)
wce = _get_wce_doc(result)
if wce:
full_result["writeConcernErrors"].append(wce)
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
"""Raise a BulkWriteError from the full bulk api result."""
# retryWrites on MMAPv1 should raise an actionable error.
if full_result["writeErrors"]:
full_result["writeErrors"].sort(key=lambda error: error["index"])
err = full_result["writeErrors"][0]
code = err["code"]
msg = err["errmsg"]
if code == 20 and msg.startswith("Transaction numbers"):
errmsg = (
"This MongoDB deployment does not support "
"retryable writes. Please add retryWrites=false "
"to your connection string."
)
raise OperationFailure(errmsg, code, full_result)
raise BulkWriteError(full_result)
class _Bulk:
class _AsyncBulk:
"""The private guts of the bulk write API."""
def __init__(
@ -180,7 +91,7 @@ class _Bulk:
comment: Optional[str] = None,
let: Optional[Any] = None,
) -> None:
"""Initialize a _Bulk instance."""
"""Initialize a _AsyncBulk instance."""
self.collection = collection.with_options(
codec_options=collection.codec_options._replace(
unicode_decode_error_handler="replace", document_class=dict
@ -204,13 +115,16 @@ class _Bulk:
# Extra state so that we know where to pick up on a retry attempt.
self.current_run = None
self.next_run = None
self.is_encrypted = False
@property
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
encrypter = self.collection.database.client._encrypter
if encrypter and not encrypter._bypass_auto_encryption:
self.is_encrypted = True
return _EncryptedBulkWriteContext
else:
self.is_encrypted = False
return _BulkWriteContext
def add_insert(self, document: _DocumentOut) -> None:
@ -315,12 +229,236 @@ class _Bulk:
if run.ops:
yield run
@_handle_reauth
async def write_command(
self,
bwc: _BulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, docs)
try:
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
bwc._fail(request_id, failure, duration)
raise
finally:
bwc.start_time = datetime.datetime.now()
return reply # type: ignore[return-value]
async def unack_write(
self,
bwc: _BulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
max_doc_size: int,
docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, docs)
try:
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
raise
finally:
bwc.start_time = datetime.datetime.now()
return result # type: ignore[return-value]
async def _execute_batch_unack(
self,
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> list[Mapping[str, Any]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
await bwc.conn.command( # type: ignore[misc]
bwc.db_name,
batched_cmd, # type: ignore[arg-type]
write_concern=WriteConcern(w=0),
session=bwc.session, # type: ignore[arg-type]
client=client, # type: ignore[arg-type]
)
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
# Though this isn't strictly a "legacy" write, the helper
# handles publishing commands and sending our message
# without receiving a result. Send 0 for max_doc_size
# to disable size checking. Size checking is handled while
# the documents are encoded to BSON.
await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type]
return to_send
async def _execute_batch(
self,
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
result = await bwc.conn.command( # type: ignore[misc]
bwc.db_name,
batched_cmd, # type: ignore[arg-type]
codec_options=bwc.codec,
session=bwc.session, # type: ignore[arg-type]
client=client, # type: ignore[arg-type]
)
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
await client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send # type: ignore[return-value]
async def _execute_command(
self,
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[ClientSession],
conn: Connection,
session: Optional[AsyncClientSession],
conn: AsyncConnection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
@ -335,8 +473,8 @@ class _Bulk:
self.next_run = None
run = self.current_run
# Connection.command validates the session, but we use
# Connection.write_command
# AsyncConnection.command validates the session, but we use
# AsyncConnection.write_command
conn.validate_session(client, session)
last_run = False
@ -387,7 +525,7 @@ class _Bulk:
# Run as many ops as possible in one command.
if write_concern.acknowledged:
result, to_send = await bwc.execute(cmd, ops, client)
result, to_send = await self._execute_batch(bwc, cmd, ops, client)
# Retryable writeConcernErrors halt the execution of this run.
wce = result.get("writeConcernError", {})
@ -407,7 +545,7 @@ class _Bulk:
if self.ordered and "writeErrors" in result:
break
else:
to_send = await bwc.execute_unack(cmd, ops, client)
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
@ -422,7 +560,7 @@ class _Bulk:
self,
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
) -> dict[str, Any]:
"""Execute using write commands."""
@ -440,7 +578,7 @@ class _Bulk:
op_id = _randint()
async def retryable_bulk(
session: Optional[ClientSession], conn: Connection, retryable: bool
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
) -> None:
await self._execute_command(
generator,
@ -458,7 +596,7 @@ class _Bulk:
retryable_bulk,
session,
operation,
bulk=self,
bulk=self, # type: ignore[arg-type]
operation_id=op_id,
)
@ -466,7 +604,9 @@ class _Bulk:
_raise_bulk_write_error(full_result)
return full_result
async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
async def execute_op_msg_no_results(
self, conn: AsyncConnection, generator: Iterator[Any]
) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name
client = self.collection.database.client
@ -499,13 +639,13 @@ class _Bulk:
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = await bwc.execute_unack(cmd, ops, client)
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)
async def execute_command_no_results(
self,
conn: Connection,
conn: AsyncConnection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
@ -541,7 +681,7 @@ class _Bulk:
async def execute_no_results(
self,
conn: Connection,
conn: AsyncConnection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
@ -573,7 +713,7 @@ class _Bulk:
async def execute(
self,
write_concern: WriteConcern,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
) -> Any:
"""Execute operations."""

View File

@ -21,17 +21,14 @@ from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.asynchronous import common
from pymongo import _csot, common
from pymongo.asynchronous.aggregation import (
_AggregationCommand,
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
)
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.operations import _Op
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (
ConnectionFailure,
CursorNotFound,
@ -39,6 +36,8 @@ from pymongo.errors import (
OperationFailure,
PyMongoError,
)
from pymongo.operations import _Op
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
_IS_SYNC = False
@ -68,11 +67,11 @@ _RESUMABLE_GETMORE_ERRORS = frozenset(
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
def _resumable(exc: PyMongoError) -> bool:
@ -114,7 +113,7 @@ class ChangeStream(Generic[_DocumentType]):
batch_size: Optional[int],
collation: Optional[_CollationIn],
start_at_operation_time: Optional[Timestamp],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
start_after: Optional[Mapping[str, Any]],
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
@ -211,7 +210,7 @@ class ChangeStream(Generic[_DocumentType]):
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> None:
"""Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents.
@ -237,7 +236,7 @@ class ChangeStream(Generic[_DocumentType]):
)
async def _run_aggregation_cmd(
self, session: Optional[ClientSession], explicit_session: bool
self, session: Optional[AsyncClientSession], explicit_session: bool
) -> AsyncCommandCursor:
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding AsyncCommandCursor.

View File

@ -1,334 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
from bson.codec_options import _parse_codec_options
from pymongo.asynchronous import common
from pymongo.asynchronous.compression_support import CompressionSettings
from pymongo.asynchronous.monitoring import _EventListener, _EventListeners
from pymongo.asynchronous.pool import PoolOptions
from pymongo.asynchronous.read_preferences import (
_ServerMode,
make_read_preference,
read_pref_mode_from_name,
)
from pymongo.asynchronous.server_selectors import any_server_selector
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern, validate_boolean
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.asynchronous.auth import MongoCredential
from pymongo.asynchronous.encryption_options import AutoEncryptionOpts
from pymongo.asynchronous.topology_description import _ServerSelector
from pymongo.pyopenssl_context import SSLContext
_IS_SYNC = False
def _parse_credentials(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
"""Parse authentication credentials."""
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
source = options.get("authsource")
if username or mechanism:
from pymongo.asynchronous.auth import _build_credentials_tuple
return _build_credentials_tuple(mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
"""Parse read preference options."""
if "read_preference" in options:
return options["read_preference"]
name = options.get("readpreference", "primary")
mode = read_pref_mode_from_name(name)
tags = options.get("readpreferencetags")
max_staleness = options.get("maxstalenessseconds", -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
"""Parse write concern options."""
concern = options.get("w")
wtimeout = options.get("wtimeoutms")
j = options.get("journal")
fsync = options.get("fsync")
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
"""Parse read concern options."""
concern = options.get("readconcernlevel")
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
validate_boolean("tls", use_tls)
certfile = options.get("tlscertificatekeyfile")
passphrase = options.get("tlscertificatekeyfilepassword")
ca_certs = options.get("tlscafile")
crlfile = options.get("tlscrlfile")
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
enabled_tls_opts = []
for opt in (
"tlscertificatekeyfile",
"tlscertificatekeyfilepassword",
"tlscafile",
"tlscrlfile",
):
# Any non-null value of these options implies tls=True.
if opt in options and options[opt]:
enabled_tls_opts.append(opt)
for opt in (
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
):
# A value of False for these options implies tls=True.
if opt in options and not options[opt]:
enabled_tls_opts.append(opt)
if enabled_tls_opts:
if use_tls is None:
# Implicitly enable TLS when one of the tls* options is set.
use_tls = True
elif not use_tls:
# Error since tls is explicitly disabled but a tls option is set.
raise ConfigurationError(
"TLS has not been enabled but the "
"following tls parameters have been set: "
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
)
if use_tls:
ctx = get_ssl_context(
certfile,
passphrase,
ca_certs,
crlfile,
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
def _parse_pool_options(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(
max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout,
socket_timeout,
wait_queue_timeout,
ssl_context,
tls_allow_invalid_hostnames,
_EventListeners(event_listeners),
appname,
driver,
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced,
credentials=credentials,
)
class ClientOptions:
"""Read only configuration options for an AsyncMongoClient.
Should not be instantiated directly by application developers. Access
a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options`
instead.
"""
def __init__(
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__direct_connection = options.get("directconnection")
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
)
self.__pool_options = _parse_pool_options(username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get("replicaset")
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get("connect")
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
self.__server_selector = options.get("server_selector", any_server_selector)
self.__auto_encryption_opts = options.get("auto_encryption_opts")
self.__load_balanced = options.get("loadbalanced")
self.__timeout = options.get("timeoutms")
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
@property
def _options(self) -> Mapping[str, Any]:
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self) -> Optional[bool]:
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self) -> CodecOptions:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def direct_connection(self) -> Optional[bool]:
"""Whether to connect to the deployment in 'Single' topology."""
return self.__direct_connection
@property
def local_threshold_ms(self) -> int:
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self) -> int:
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self) -> _ServerSelector:
return self.__server_selector
@property
def heartbeat_frequency(self) -> int:
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self) -> PoolOptions:
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self) -> _ServerMode:
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self) -> WriteConcern:
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self) -> ReadConcern:
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def timeout(self) -> Optional[float]:
"""The configured timeoutMS converted to seconds, or None.
.. versionadded:: 4.2
"""
return self.__timeout
@property
def retry_writes(self) -> bool:
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self) -> bool:
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self.__load_balanced
@property
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()
@property
def server_monitoring_mode(self) -> str:
"""The configured serverMonitoringMode option.
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode

View File

@ -44,8 +44,8 @@ Transactions
.. versionadded:: 3.7
MongoDB 4.0 adds support for transactions on replica set primaries. A
transaction is associated with a :class:`ClientSession`. To start a transaction
on a session, use :meth:`ClientSession.start_transaction` in a with-statement.
transaction is associated with a :class:`AsyncClientSession`. To start a transaction
on a session, use :meth:`AsyncClientSession.start_transaction` in a with-statement.
Then, execute an operation within the transaction by passing the session to the
operation:
@ -63,9 +63,9 @@ operation:
)
Upon normal completion of ``async with session.start_transaction()`` block, the
transaction automatically calls :meth:`ClientSession.commit_transaction`.
transaction automatically calls :meth:`AsyncClientSession.commit_transaction`.
If the block exits with an exception, the transaction automatically calls
:meth:`ClientSession.abort_transaction`.
:meth:`AsyncClientSession.abort_transaction`.
In general, multi-document transactions only support read/write (CRUD)
operations on existing collections. However, MongoDB 4.4 adds support for
@ -157,8 +157,6 @@ from bson.int64 import Int64
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.operations import _Op
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
@ -167,23 +165,25 @@ from pymongo.errors import (
PyMongoError,
WTimeoutError,
)
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.operations import _Op
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_type import SERVER_TYPE
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from types import TracebackType
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.asynchronous.typings import ClusterTime, _Address
from pymongo.typings import ClusterTime, _Address
_IS_SYNC = False
class SessionOptions:
"""Options for a new :class:`ClientSession`.
"""Options for a new :class:`AsyncClientSession`.
:param causal_consistency: If True, read operations are causally
ordered within the session. Defaults to True when the ``snapshot``
@ -246,7 +246,7 @@ class SessionOptions:
class TransactionOptions:
"""Options for :meth:`ClientSession.start_transaction`.
"""Options for :meth:`AsyncClientSession.start_transaction`.
:param read_concern: The
:class:`~pymongo.read_concern.ReadConcern` to use for this transaction.
@ -336,8 +336,8 @@ class TransactionOptions:
def _validate_session_write_concern(
session: Optional[ClientSession], write_concern: Optional[WriteConcern]
) -> Optional[ClientSession]:
session: Optional[AsyncClientSession], write_concern: Optional[WriteConcern]
) -> Optional[AsyncClientSession]:
"""Validate that an explicit session is not used with an unack'ed write.
Returns the session to use for the next operation.
@ -362,7 +362,7 @@ def _validate_session_write_concern(
class _TransactionContext:
"""Internal transaction context manager for start_transaction."""
def __init__(self, session: ClientSession):
def __init__(self, session: AsyncClientSession):
self.__session = session
async def __aenter__(self) -> _TransactionContext:
@ -391,7 +391,7 @@ class _TxnState:
class _Transaction:
"""Internal class to hold transaction information in a ClientSession."""
"""Internal class to hold transaction information in a AsyncClientSession."""
def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient):
self.opts = opts
@ -410,12 +410,12 @@ class _Transaction:
return self.state == _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[Connection]:
def pinned_conn(self) -> Optional[AsyncConnection]:
if self.active() and self.conn_mgr:
return self.conn_mgr.conn
return None
def pin(self, server: Server, conn: Connection) -> None:
def pin(self, server: Server, conn: AsyncConnection) -> None:
self.sharded = True
self.pinned_address = server.description.address
if server.description.server_type == SERVER_TYPE.LoadBalancer:
@ -481,16 +481,16 @@ if TYPE_CHECKING:
from pymongo.asynchronous.mongo_client import AsyncMongoClient
class ClientSession:
class AsyncClientSession:
"""A session for ordering sequential operations.
:class:`ClientSession` instances are **not thread-safe or fork-safe**.
:class:`AsyncClientSession` instances are **not thread-safe or fork-safe**.
They can only be used by one thread or process at a time. A single
:class:`ClientSession` cannot be used to run multiple operations
:class:`AsyncClientSession` cannot be used to run multiple operations
concurrently.
Should not be initialized directly by application developers - to create a
:class:`ClientSession`, call
:class:`AsyncClientSession`, call
:meth:`~pymongo.mongo_client.AsyncMongoClient.start_session`.
"""
@ -541,7 +541,7 @@ class ClientSession:
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
async def __aenter__(self) -> ClientSession:
async def __aenter__(self) -> AsyncClientSession:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
@ -560,14 +560,14 @@ class ClientSession:
return self._options
@property
async def session_id(self) -> Mapping[str, Any]:
def session_id(self) -> Mapping[str, Any]:
"""A BSON document, the opaque server session identifier."""
self._check_ended()
self._materialize(self._client.topology_description.logical_session_timeout_minutes)
return self._server_session.session_id
@property
async def _transaction_id(self) -> Int64:
def _transaction_id(self) -> Int64:
"""The current transaction id for the underlying server session."""
self._materialize(self._client.topology_description.logical_session_timeout_minutes)
return self._server_session.transaction_id
@ -598,7 +598,7 @@ class ClientSession:
async def with_transaction(
self,
callback: Callable[[ClientSession], _T],
callback: Callable[[AsyncClientSession], _T],
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
@ -646,25 +646,25 @@ class ClientSession:
however, ``with_transaction`` will return without taking further
action.
:class:`ClientSession` instances are **not thread-safe or fork-safe**.
:class:`AsyncClientSession` instances are **not thread-safe or fork-safe**.
Consequently, the ``callback`` must not attempt to execute multiple
operations concurrently.
When ``callback`` raises an exception, ``with_transaction``
automatically aborts the current transaction. When ``callback`` or
:meth:`~ClientSession.commit_transaction` raises an exception that
:meth:`~AsyncClientSession.commit_transaction` raises an exception that
includes the ``"TransientTransactionError"`` error label,
``with_transaction`` starts a new transaction and re-executes
the ``callback``.
When :meth:`~ClientSession.commit_transaction` raises an exception with
When :meth:`~AsyncClientSession.commit_transaction` raises an exception with
the ``"UnknownTransactionCommitResult"`` error label,
``with_transaction`` retries the commit until the result of the
transaction is known.
This method will cease retrying after 120 seconds has elapsed. This
timeout is not configurable and any exception raised by the
``callback`` or by :meth:`ClientSession.commit_transaction` after the
``callback`` or by :meth:`AsyncClientSession.commit_transaction` after the
timeout is reached will be re-raised. Applications that desire a
different timeout duration should not use this method.
@ -850,7 +850,7 @@ class ClientSession:
"""
async def func(
_session: Optional[ClientSession], conn: Connection, _retryable: bool
_session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool
) -> dict[str, Any]:
return await self._finish_transaction(conn, command_name)
@ -858,7 +858,7 @@ class ClientSession:
func, self, None, retryable=True, operation=_Op.ABORT
)
async def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]:
async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]:
self._transaction.attempt += 1
opts = self._transaction.opts
assert opts
@ -897,8 +897,8 @@ class ClientSession:
"""Update the cluster time for this session.
:param cluster_time: The
:data:`~pymongo.client_session.ClientSession.cluster_time` from
another `ClientSession` instance.
:data:`~pymongo.client_session.AsyncClientSession.cluster_time` from
another `AsyncClientSession` instance.
"""
if not isinstance(cluster_time, _Mapping):
raise TypeError("cluster_time must be a subclass of collections.Mapping")
@ -918,8 +918,8 @@ class ClientSession:
"""Update the operation time for this session.
:param operation_time: The
:data:`~pymongo.client_session.ClientSession.operation_time` from
another `ClientSession` instance.
:data:`~pymongo.client_session.AsyncClientSession.operation_time` from
another `AsyncClientSession` instance.
"""
if not isinstance(operation_time, Timestamp):
raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp")
@ -966,11 +966,11 @@ class ClientSession:
return None
@property
def _pinned_connection(self) -> Optional[Connection]:
def _pinned_connection(self) -> Optional[AsyncConnection]:
"""The connection this transaction was started on."""
return self._transaction.pinned_conn
def _pin(self, server: Server, conn: Connection) -> None:
def _pin(self, server: Server, conn: AsyncConnection) -> None:
"""Pin this session to the given Server or to the given connection."""
self._transaction.pin(server, conn)
@ -999,7 +999,7 @@ class ClientSession:
command: MutableMapping[str, Any],
is_retryable: bool,
read_preference: _ServerMode,
conn: Connection,
conn: AsyncConnection,
) -> None:
if not conn.supports_sessions:
if not self._implicit:
@ -1042,7 +1042,7 @@ class ClientSession:
self._check_ended()
self._server_session.inc_transaction_id()
def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None:
def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: AsyncConnection) -> None:
if self.options.causal_consistency and self.operation_time is not None:
cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time
if self.options.snapshot:
@ -1054,7 +1054,7 @@ class ClientSession:
rc["atClusterTime"] = self._snapshot_time
def __copy__(self) -> NoReturn:
raise TypeError("A ClientSession cannot be copied, create a new session instead")
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")
class _EmptyServerSession:

View File

@ -1,226 +0,0 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with `collations`_.
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
from __future__ import annotations
from typing import Any, Mapping, Optional, Union
from pymongo.asynchronous import common
from pymongo.write_concern import validate_boolean
_IS_SYNC = False
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = "non-ignorable"
"""Spaces and punctuation are treated as base characters."""
SHIFTED = "shifted"
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished regardless when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = "punct"
"""Both punctuation and spaces are ignored."""
SPACE = "space"
"""Spaces alone are ignored."""
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = "upper"
"""Sort uppercase characters first."""
LOWER = "lower"
"""Sort lowercase characters first."""
OFF = "off"
"""Default for locale or collation strength."""
class Collation:
"""Collation
:param locale: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
:param caseFirst: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
:param strength: Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
:param numericOrdering: If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
:param alternate: Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
:param maxVariable: When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
:param normalization: If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
:param backwards: If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
:param kwargs: Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded: 3.4
"""
__slots__ = ("__document",)
def __init__(
self,
locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
if strength is not None:
self.__document["strength"] = common.validate_integer("strength", strength)
if numericOrdering is not None:
self.__document["numericOrdering"] = validate_boolean(
"numericOrdering", numericOrdering
)
if alternate is not None:
self.__document["alternate"] = common.validate_string("alternate", alternate)
if maxVariable is not None:
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
if normalization is not None:
self.__document["normalization"] = validate_boolean("normalization", normalization)
if backwards is not None:
self.__document["backwards"] = validate_boolean("backwards", backwards)
self.__document.update(kwargs)
@property
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self) -> str:
document = self.document
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")

View File

@ -42,27 +42,32 @@ from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import ASCENDING, _csot
from pymongo.asynchronous import common, helpers, message
from pymongo import ASCENDING, _csot, common, helpers_shared, message
from pymongo.asynchronous.aggregation import (
_CollectionAggregationCommand,
_CollectionRawAggregationCommand,
)
from pymongo.asynchronous.bulk import _Bulk
from pymongo.asynchronous.bulk import _AsyncBulk
from pymongo.asynchronous.change_stream import CollectionChangeStream
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.command_cursor import (
AsyncCommandCursor,
AsyncRawBatchCommandCursor,
)
from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name
from pymongo.asynchronous.cursor import (
AsyncCursor,
AsyncRawBatchCursor,
)
from pymongo.asynchronous.helpers import _check_write_command_response
from pymongo.asynchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS
from pymongo.asynchronous.operations import (
from pymongo.collation import validate_collation_or_none
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.errors import (
ConfigurationError,
InvalidName,
InvalidOperation,
OperationFailure,
)
from pymongo.helpers_shared import _check_write_command_response
from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
from pymongo.operations import (
DeleteMany,
DeleteOne,
IndexModel,
@ -75,15 +80,8 @@ from pymongo.asynchronous.operations import (
_IndexList,
_Op,
)
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.errors import (
ConfigurationError,
InvalidName,
InvalidOperation,
OperationFailure,
)
from pymongo.read_concern import DEFAULT_READ_CONCERN
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import (
BulkWriteResult,
DeleteResult,
@ -91,6 +89,7 @@ from pymongo.results import (
InsertOneResult,
UpdateResult,
)
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
_IS_SYNC = False
@ -127,11 +126,11 @@ class ReturnDocument:
if TYPE_CHECKING:
import bson
from pymongo.asynchronous.aggregation import _AggregationCommand
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.collation import Collation
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.collation import Collation
from pymongo.read_concern import ReadConcern
@ -147,7 +146,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> None:
"""Get / create an asynchronous Mongo collection.
@ -373,7 +372,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
)
def _write_concern_for_cmd(
self, cmd: Mapping[str, Any], session: Optional[ClientSession]
self, cmd: Mapping[str, Any], session: Optional[AsyncClientSession]
) -> WriteConcern:
raw_wc = cmd.get("writeConcern")
if raw_wc is not None:
@ -413,7 +412,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
batch_size: Optional[int] = None,
collation: Optional[_CollationIn] = None,
start_at_operation_time: Optional[Timestamp] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
start_after: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
@ -494,7 +493,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
the specified :class:`~bson.timestamp.Timestamp`. Requires
MongoDB >= 4.0.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param start_after: The same as `resume_after` except that
`start_after` can resume notifications after an invalidate event.
This option and `resume_after` are mutually exclusive.
@ -546,13 +545,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
return change_stream
async def _conn_for_writes(
self, session: Optional[ClientSession], operation: str
) -> AsyncContextManager[Connection]:
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
return await self._database.client._conn_for_writes(session, operation)
async def _command(
self,
conn: Connection,
conn: AsyncConnection,
command: MutableMapping[str, Any],
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = None,
@ -561,13 +560,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
collation: Optional[_CollationIn] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
retryable_write: bool = False,
user_fields: Optional[Any] = None,
) -> Mapping[str, Any]:
"""Internal command helper.
:param conn` - A Connection instance.
:param conn` - A AsyncConnection instance.
:param command` - The command itself, as a :class:`~bson.son.SON` instance.
:param read_preference` (optional) - The read preference to use.
:param codec_options` (optional) - An instance of
@ -581,7 +580,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param collation` (optional) - An instance of
:class:`~pymongo.collation.Collation`.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param retryable_write: True if this command is a retryable
write.
:param user_fields: Response fields that should be decoded
@ -613,7 +612,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
name: str,
options: MutableMapping[str, Any],
collation: Optional[_CollationIn],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
encrypted_fields: Optional[Mapping[str, Any]] = None,
qev2_required: bool = False,
) -> None:
@ -646,7 +645,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _create(
self,
options: MutableMapping[str, Any],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
) -> None:
collation = validate_collation_or_none(options.pop("collation", None))
encrypted_fields = options.pop("encryptedFields", None)
@ -676,7 +675,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
requests: Sequence[_WriteOp[_DocumentType]],
ordered: bool = True,
bypass_document_validation: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
) -> BulkWriteResult:
@ -726,7 +725,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write to opt-out of document level validation. Default is
``False``.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param let: Map of parameter names and values. Values must be
@ -755,7 +754,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"""
common.validate_list("requests", requests)
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
for request in requests:
try:
request._add_to_bulk(blk)
@ -775,7 +774,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern: WriteConcern,
op_id: Optional[int],
bypass_doc_val: bool,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
comment: Optional[Any] = None,
) -> Any:
"""Internal helper for inserting a single document."""
@ -786,7 +785,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
command["comment"] = comment
async def _insert_command(
session: Optional[ClientSession], conn: Connection, retryable_write: bool
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
) -> None:
if bypass_doc_val:
command["bypassDocumentValidation"] = True
@ -815,7 +814,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self,
document: Union[_DocumentType, RawBSONDocument],
bypass_document_validation: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> InsertOneResult:
"""Insert a single document.
@ -835,7 +834,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write to opt-out of document level validation. Default is
``False``.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -881,7 +880,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
ordered: bool = True,
bypass_document_validation: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> InsertManyResult:
"""Insert an iterable of documents.
@ -904,7 +903,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write to opt-out of document level validation. Default is
``False``.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -945,14 +944,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
yield (message._INSERT, document)
write_concern = self._write_concern_for(session)
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment)
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment)
blk.ops = list(gen())
await blk.execute(write_concern, session, _Op.INSERT)
return InsertManyResult(inserted_ids, write_concern.acknowledged)
async def _update(
self,
conn: Connection,
conn: AsyncConnection,
criteria: Mapping[str, Any],
document: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
@ -964,7 +963,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
retryable_write: bool = False,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
@ -996,7 +995,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
if not isinstance(hint, str):
hint = helpers._index_document(hint)
hint = helpers_shared._index_document(hint)
update_doc["hint"] = hint
command = {"update": self.name, "ordered": ordered, "updates": [update_doc]}
if let is not None:
@ -1051,14 +1050,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> Optional[Mapping[str, Any]]:
"""Internal update / replace helper."""
async def _update(
session: Optional[ClientSession], conn: Connection, retryable_write: bool
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
) -> Optional[Mapping[str, Any]]:
return await self._update(
conn,
@ -1094,7 +1093,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
bypass_document_validation: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> UpdateResult:
@ -1143,7 +1142,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -1197,7 +1196,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> UpdateResult:
@ -1252,7 +1251,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -1310,7 +1309,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
bypass_document_validation: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> UpdateResult:
@ -1352,7 +1351,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -1404,14 +1403,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def drop(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
encrypted_fields: Optional[Mapping[str, Any]] = None,
) -> None:
"""Alias for :meth:`~pymongo.database.AsyncDatabase.drop_collection`.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param encrypted_fields: **(BETA)** Document that describes the encrypted fields for
@ -1447,7 +1446,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _delete(
self,
conn: Connection,
conn: AsyncConnection,
criteria: Mapping[str, Any],
multi: bool,
write_concern: Optional[WriteConcern] = None,
@ -1455,7 +1454,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
ordered: bool = True,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
retryable_write: bool = False,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
@ -1477,7 +1476,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
if not isinstance(hint, str):
hint = helpers._index_document(hint)
hint = helpers_shared._index_document(hint)
delete_doc["hint"] = hint
command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]}
@ -1510,14 +1509,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
ordered: bool = True,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> Mapping[str, Any]:
"""Internal delete helper."""
async def _delete(
session: Optional[ClientSession], conn: Connection, retryable_write: bool
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
) -> Mapping[str, Any]:
return await self._delete(
conn,
@ -1546,7 +1545,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> DeleteResult:
@ -1570,7 +1569,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -1611,7 +1610,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
) -> DeleteResult:
@ -1635,7 +1634,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -1737,7 +1736,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
always be returned. Use a dict to exclude fields from
the result (e.g. projection={'_id': False}).
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param skip: the number of documents to omit (from
the start of the result set) when returning the results
:param limit: the maximum number of results to
@ -1931,8 +1930,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _count_cmd(
self,
session: Optional[ClientSession],
conn: Connection,
session: Optional[AsyncClientSession],
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
cmd: dict[str, Any],
collation: Optional[Collation],
@ -1956,11 +1955,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _aggregate_one_result(
self,
conn: Connection,
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
cmd: dict[str, Any],
collation: Optional[_CollationIn],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
) -> Optional[Mapping[str, Any]]:
"""Internal helper to run an aggregate that returns a single result."""
result = await self._command(
@ -2012,9 +2011,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
async def _cmd(
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
_server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
) -> int:
cmd: dict[str, Any] = {"count": self._name}
@ -2026,7 +2025,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def count_documents(
self,
filter: Mapping[str, Any],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> int:
@ -2072,7 +2071,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
to count in the collection. Can be an empty document to count all
documents.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: See list of options above.
@ -2095,14 +2094,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}})
cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}}
if "hint" in kwargs and not isinstance(kwargs["hint"], str):
kwargs["hint"] = helpers._index_document(kwargs["hint"])
kwargs["hint"] = helpers_shared._index_document(kwargs["hint"])
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd.update(kwargs)
async def _cmd(
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
_server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
) -> int:
result = await self._aggregate_one_result(
@ -2117,10 +2116,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _retryable_non_cursor_read(
self,
func: Callable[
[Optional[ClientSession], Server, Connection, Optional[_ServerMode]],
[Optional[AsyncClientSession], Server, AsyncConnection, Optional[_ServerMode]],
Coroutine[Any, Any, T],
],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
) -> T:
"""Non-cursor read helper to handle implicit session creation."""
@ -2131,7 +2130,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def create_indexes(
self,
indexes: Sequence[IndexModel],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> list[str]:
@ -2147,7 +2146,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param indexes: A list of :class:`~pymongo.operations.IndexModel`
instances.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the createIndexes
@ -2177,14 +2176,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
@_csot.apply
async def _create_indexes(
self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any
self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any
) -> list[str]:
"""Internal createIndexes helper.
:param indexes: A list of :class:`~pymongo.operations.IndexModel`
instances.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param kwargs: optional arguments to the createIndexes
command (like maxTimeMS) can be passed as keyword arguments.
"""
@ -2223,7 +2222,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def create_index(
self,
keys: _IndexKeyHint,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> str:
@ -2299,7 +2298,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param keys: a single key or a list of (key, direction)
pairs specifying the index to create
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: any additional index creation
@ -2340,7 +2339,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def drop_indexes(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> None:
@ -2350,7 +2349,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
Raises OperationFailure on an error.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the createIndexes
@ -2375,7 +2374,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def drop_index(
self,
index_or_name: _IndexKeyHint,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> None:
@ -2397,7 +2396,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param index_or_name: index (or name of index) to drop
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the createIndexes
@ -2424,13 +2423,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _drop_index(
self,
index_or_name: _IndexKeyHint,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> None:
name = index_or_name
if isinstance(index_or_name, list):
name = helpers._gen_index_name(index_or_name)
name = helpers_shared._gen_index_name(index_or_name)
if not isinstance(name, str):
raise TypeError("index_or_name must be an instance of str or list")
@ -2451,7 +2450,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def list_indexes(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
"""Get a cursor over the index documents for this collection.
@ -2462,7 +2461,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
SON([('v', 2), ('key', SON([('_id', 1)])), ('name', '_id_')])
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -2480,7 +2479,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _list_indexes(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
codec_options: CodecOptions = CodecOptions(SON)
@ -2492,9 +2491,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
explicit_session = session is not None
async def _cmd(
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
_server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
cmd = {"listIndexes": self._name, "cursor": {}}
@ -2529,7 +2528,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def index_information(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> MutableMapping[str, Any]:
"""Get information on this collection's indexes.
@ -2551,7 +2550,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
'x_1': {'unique': True, 'key': [('x', 1)]}}
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -2572,7 +2571,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def list_search_indexes(
self,
name: Optional[str] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> AsyncCommandCursor[Mapping[str, Any]]:
@ -2582,7 +2581,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
for. Only indexes with matching index names will be returned.
If not given, all search indexes for the current collection
will be returned.
:param session: a :class:`~pymongo.client_session.ClientSession`.
:param session: a :class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -2625,7 +2624,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def create_search_index(
self,
model: Union[Mapping[str, Any], SearchIndexModel],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Any = None,
**kwargs: Any,
) -> str:
@ -2636,7 +2635,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
instance or a dictionary with a model "definition" and optional
"name".
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the createSearchIndexes
@ -2655,14 +2654,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def create_search_indexes(
self,
models: list[SearchIndexModel],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> list[str]:
"""Create multiple search indexes for the current collection.
:param models: A list of :class:`~pymongo.operations.SearchIndexModel` instances.
:param session: a :class:`~pymongo.client_session.ClientSession`.
:param session: a :class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the createSearchIndexes
@ -2679,7 +2678,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _create_search_indexes(
self,
models: list[SearchIndexModel],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> list[str]:
@ -2711,7 +2710,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def drop_search_index(
self,
name: str,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> None:
@ -2719,7 +2718,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param name: The name of the search index to be deleted.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the dropSearchIndexes
@ -2746,7 +2745,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self,
name: str,
definition: Mapping[str, Any],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> None:
@ -2755,7 +2754,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param name: The name of the search index to be updated.
:param definition: The new search index definition.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: optional arguments to the updateSearchIndexes
@ -2780,7 +2779,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def options(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> MutableMapping[str, Any]:
"""Get the options set on this collection.
@ -2791,7 +2790,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
dictionary if the collection has not been created yet.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -2830,7 +2829,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
aggregation_command: Type[_AggregationCommand],
pipeline: _Pipeline,
cursor_class: Type[AsyncCommandCursor],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
@ -2859,7 +2858,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def aggregate(
self,
pipeline: _Pipeline,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -2882,7 +2881,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param pipeline: a list of aggregation pipeline stages
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: A dict of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -2954,7 +2953,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def aggregate_raw_batches(
self,
pipeline: _Pipeline,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> AsyncRawBatchCursor[_DocumentType]:
@ -3003,7 +3002,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def rename(
self,
new_name: str,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> MutableMapping[str, Any]:
@ -3017,7 +3016,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param new_name: new name for this collection
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: additional arguments to the rename command
@ -3067,7 +3066,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self,
key: str,
filter: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> list:
@ -3093,7 +3092,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:param filter: A query document that specifies the documents
from which to retrieve the distinct values.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: See list of options above.
@ -3118,9 +3117,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd["comment"] = comment
async def _cmd(
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
_server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
) -> list:
return (
@ -3146,7 +3145,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
return_document: bool = ReturnDocument.BEFORE,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping] = None,
**kwargs: Any,
) -> Any:
@ -3163,20 +3162,20 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd["let"] = let
cmd.update(kwargs)
if projection is not None:
cmd["fields"] = helpers._fields_list_to_dict(projection, "projection")
cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection")
if sort is not None:
cmd["sort"] = helpers._index_document(sort)
cmd["sort"] = helpers_shared._index_document(sort)
if upsert is not None:
validate_boolean("upsert", upsert)
cmd["upsert"] = upsert
if hint is not None:
if not isinstance(hint, str):
hint = helpers._index_document(hint)
hint = helpers_shared._index_document(hint)
write_concern = self._write_concern_for_cmd(cmd, session)
async def _find_and_modify_helper(
session: Optional[ClientSession], conn: Connection, retryable_write: bool
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
) -> Any:
acknowledged = write_concern.acknowledged
if array_filters is not None:
@ -3222,7 +3221,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
sort: Optional[_IndexList] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -3268,7 +3267,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
(e.g. ``[('field', ASCENDING)]``). This option is only supported
on MongoDB 4.4 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -3314,7 +3313,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
upsert: bool = False,
return_document: bool = ReturnDocument.BEFORE,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -3366,7 +3365,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
@ -3422,7 +3421,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
return_document: bool = ReturnDocument.BEFORE,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -3513,7 +3512,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param let: Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an

View File

@ -30,22 +30,22 @@ from typing import (
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.message import (
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
_RawBatchGetMore,
)
from pymongo.asynchronous.response import PinnedResponse
from pymongo.asynchronous.typings import _Address, _DocumentOut, _DocumentType
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
_IS_SYNC = False
@ -62,7 +62,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
@ -133,7 +133,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
"""
return self._postbatchresumetoken
async def _maybe_pin_connection(self, conn: Connection) -> None:
async def _maybe_pin_connection(self, conn: AsyncConnection) -> None:
client = self._collection.database.client
if not client._should_pin_cursor(self._session):
return
@ -188,8 +188,8 @@ class AsyncCommandCursor(Generic[_DocumentType]):
return self._address
@property
def session(self) -> Optional[ClientSession]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
def session(self) -> Optional[AsyncClientSession]:
"""The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None.
.. versionadded:: 3.6
"""
@ -272,7 +272,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
@ -384,7 +384,7 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:

File diff suppressed because it is too large Load Diff

View File

@ -1,178 +0,0 @@
# Copyright 2018 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Any, Iterable, Optional, Union
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
_IS_SYNC = False
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
def _have_snappy() -> bool:
try:
import snappy # type:ignore[import] # noqa: F401
return True
except ImportError:
return False
def _have_zlib() -> bool:
try:
import zlib # noqa: F401
return True
except ImportError:
return False
def _have_zstd() -> bool:
try:
import zstandard # noqa: F401
return True
except ImportError:
return False
def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]:
try:
# `value` is string.
compressors = value.split(",") # type: ignore[union-attr]
except AttributeError:
# `value` is an iterable.
compressors = list(value)
for compressor in compressors[:]:
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2)
elif compressor == "snappy" and not _have_snappy():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with snappy is not available. "
"You must install the python-snappy module for snappy support.",
stacklevel=2,
)
elif compressor == "zlib" and not _have_zlib():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zlib is not available. "
"The zlib module is not available.",
stacklevel=2,
)
elif compressor == "zstd" and not _have_zstd():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"You must install the zstandard module for zstandard support.",
stacklevel=2,
)
return compressors
def validate_zlib_compression_level(option: str, value: Any) -> int:
try:
level = int(value)
except Exception:
raise TypeError(f"{option} must be an integer, not {value!r}.") from None
if level < -1 or level > 9:
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
return level
class CompressionSettings:
def __init__(self, compressors: list[str], zlib_compression_level: int):
self.compressors = compressors
self.zlib_compression_level = zlib_compression_level
def get_compression_context(
self, compressors: Optional[list[str]]
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
if compressors:
chosen = compressors[0]
if chosen == "snappy":
return SnappyContext()
elif chosen == "zlib":
return ZlibContext(self.zlib_compression_level)
elif chosen == "zstd":
return ZstdContext()
return None
return None
class SnappyContext:
compressor_id = 1
@staticmethod
def compress(data: bytes) -> bytes:
import snappy
return snappy.compress(data)
class ZlibContext:
compressor_id = 2
def __init__(self, level: int):
self.level = level
def compress(self, data: bytes) -> bytes:
import zlib
return zlib.compress(data, self.level)
class ZstdContext:
compressor_id = 3
@staticmethod
def compress(data: bytes) -> bytes:
# ZstdCompressor is not thread safe.
# TODO: Use a pool?
import zstandard
return zstandard.ZstdCompressor().compress(data)
def decompress(data: bytes, compressor_id: int) -> bytes:
if compressor_id == SnappyContext.compressor_id:
# python-snappy doesn't support the buffer interface.
# https://github.com/andrix/python-snappy/issues/65
# This only matters when data is a memoryview since
# id(bytes(data)) == id(data) when data is a bytes.
import snappy
return snappy.uncompress(bytes(data))
elif compressor_id == ZlibContext.compressor_id:
import zlib
return zlib.decompress(data)
elif compressor_id == ZstdContext.compressor_id:
# ZstdDecompressor is not thread safe.
# TODO: Use a pool?
import zstandard
return zstandard.ZstdDecompressor().decompress(data)
else:
raise ValueError("Unknown compressorId %d" % (compressor_id,))

View File

@ -36,14 +36,17 @@ from typing import (
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
from bson.code import Code
from bson.son import SON
from pymongo.asynchronous import helpers
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.common import (
from pymongo import helpers_shared
from pymongo.asynchronous.helpers import anext
from pymongo.collation import validate_collation_or_none
from pymongo.common import (
validate_is_document_type,
validate_is_mapping,
)
from pymongo.asynchronous.helpers import anext
from pymongo.asynchronous.message import (
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _ALock, _create_lock
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
@ -52,21 +55,18 @@ from pymongo.asynchronous.message import (
_RawBatchGetMore,
_RawBatchQuery,
)
from pymongo.asynchronous.response import PinnedResponse
from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _ALock, _create_lock
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.write_concern import validate_boolean
if TYPE_CHECKING:
from _typeshed import SupportsItems
from bson.codec_options import CodecOptions
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.read_preferences import _ServerMode
_IS_SYNC = False
@ -74,8 +74,8 @@ _IS_SYNC = False
class _ConnectionManager:
"""Used with exhaust cursors to ensure the connection is returned."""
def __init__(self, conn: Connection, more_to_come: bool):
self.conn: Optional[Connection] = conn
def __init__(self, conn: AsyncConnection, more_to_come: bool):
self.conn: Optional[AsyncConnection] = conn
self.more_to_come = more_to_come
self._alock = _ALock(_create_lock())
@ -116,7 +116,7 @@ class AsyncCursor(Generic[_DocumentType]):
show_record_id: Optional[bool] = None,
snapshot: Optional[bool] = None,
comment: Optional[Any] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
allow_disk_use: Optional[bool] = None,
let: Optional[bool] = None,
) -> None:
@ -134,7 +134,7 @@ class AsyncCursor(Generic[_DocumentType]):
self._exhaust = False
self._sock_mgr: Any = None
self._killed = False
self._session: Optional[ClientSession]
self._session: Optional[AsyncClientSession]
if session:
self._session = session
@ -179,7 +179,7 @@ class AsyncCursor(Generic[_DocumentType]):
allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use)
if projection is not None:
projection = helpers._fields_list_to_dict(projection, "projection")
projection = helpers_shared._fields_list_to_dict(projection, "projection")
if let is not None:
validate_is_document_type("let", let)
@ -191,7 +191,7 @@ class AsyncCursor(Generic[_DocumentType]):
self._skip = skip
self._limit = limit
self._batch_size = batch_size
self._ordering = sort and helpers._index_document(sort) or None
self._ordering = sort and helpers_shared._index_document(sort) or None
self._max_scan = max_scan
self._explain = False
self._comment = comment
@ -313,7 +313,7 @@ class AsyncCursor(Generic[_DocumentType]):
base.__dict__.update(data)
return base
def _clone_base(self, session: Optional[ClientSession]) -> AsyncCursor:
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor:
"""Creates an empty Cursor object for information to be copied into."""
return self.__class__(self._collection, session=session)
@ -742,8 +742,8 @@ class AsyncCursor(Generic[_DocumentType]):
key, if not given :data:`~pymongo.ASCENDING` is assumed
"""
self._check_okay_to_chain()
keys = helpers._index_list(key_or_list, direction)
self._ordering = helpers._index_document(keys)
keys = helpers_shared._index_list(key_or_list, direction)
self._ordering = helpers_shared._index_document(keys)
return self
async def explain(self) -> _DocumentType:
@ -774,7 +774,7 @@ class AsyncCursor(Generic[_DocumentType]):
if isinstance(index, str):
self._hint = index
else:
self._hint = helpers._index_document(index)
self._hint = helpers_shared._index_document(index)
def hint(self, index: Optional[_Hint]) -> AsyncCursor[_DocumentType]:
"""Adds a 'hint', telling Mongo the proper index to use for the query.
@ -927,8 +927,8 @@ class AsyncCursor(Generic[_DocumentType]):
return self._address
@property
def session(self) -> Optional[ClientSession]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
def session(self) -> Optional[AsyncClientSession]:
"""The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None.
.. versionadded:: 3.6
"""
@ -1122,7 +1122,7 @@ class AsyncCursor(Generic[_DocumentType]):
self._address = response.address
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
cmd_name = operation.name
docs = response.docs

View File

@ -33,25 +33,24 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
from bson.dbref import DBRef
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.asynchronous import common
from pymongo import _csot, common
from pymongo.asynchronous.aggregation import _DatabaseAggregationCommand
from pymongo.asynchronous.change_stream import DatabaseChangeStream
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name
from pymongo.asynchronous.operations import _Op
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.database_shared import _check_name, _CodecDocumentType
from pymongo.errors import CollectionInvalid, InvalidOperation
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
if TYPE_CHECKING:
import bson
import bson.codec_options
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
@ -151,7 +150,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
>>> db1.read_preference
Primary()
>>> from pymongo.asynchronous.read_preferences import Secondary
>>> from pymongo.read_preferences import Secondary
>>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}]))
>>> db1.read_preference
Primary()
@ -328,7 +327,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
batch_size: Optional[int] = None,
collation: Optional[_CollationIn] = None,
start_at_operation_time: Optional[Timestamp] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
start_after: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
@ -402,7 +401,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
the specified :class:`~bson.timestamp.Timestamp`. Requires
MongoDB >= 4.0.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param start_after: The same as `resume_after` except that
`start_after` can resume notifications after an invalidate event.
This option and `resume_after` are mutually exclusive.
@ -458,7 +457,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
check_exists: Optional[bool] = True,
**kwargs: Any,
) -> AsyncCollection[_DocumentType]:
@ -489,7 +488,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param `check_exists`: if True (the default), send a listCollections command to
check if the collection already exists before creation.
:param kwargs: additional keyword arguments will
@ -607,7 +606,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
return coll
async def aggregate(
self, pipeline: _Pipeline, session: Optional[ClientSession] = None, **kwargs: Any
self, pipeline: _Pipeline, session: Optional[AsyncClientSession] = None, **kwargs: Any
) -> AsyncCommandCursor[_DocumentType]:
"""Perform a database-level aggregation.
@ -634,7 +633,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
:param pipeline: a list of aggregation pipeline stages
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param kwargs: extra `aggregate command`_ parameters.
All optional `aggregate command`_ parameters should be passed as
@ -688,7 +687,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
@overload
async def _command(
self,
conn: Connection,
conn: AsyncConnection,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
@ -697,7 +696,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> dict[str, Any]:
...
@ -705,7 +704,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
@overload
async def _command(
self,
conn: Connection,
conn: AsyncConnection,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
@ -714,14 +713,14 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
codec_options: CodecOptions[_CodecDocumentType] = ...,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> _CodecDocumentType:
...
async def _command(
self,
conn: Connection,
conn: AsyncConnection,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
@ -732,7 +731,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> Union[dict[str, Any], _CodecDocumentType]:
"""Internal command helper."""
@ -763,7 +762,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: None = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> dict[str, Any]:
@ -778,7 +777,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: CodecOptions[_CodecDocumentType] = ...,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> _CodecDocumentType:
@ -793,7 +792,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> Union[dict[str, Any], _CodecDocumentType]:
@ -852,7 +851,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
:param codec_options: A :class:`~bson.codec_options.CodecOptions`
instance.
:param session: A
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: additional keyword arguments will
@ -922,7 +921,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
value: Any = 1,
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions[_CodecDocumentType]] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
max_await_time_ms: Optional[int] = None,
**kwargs: Any,
@ -953,7 +952,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
:param codec_options`: A :class:`~bson.codec_options.CodecOptions`
instance.
:param session: A
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to future getMores for this
command.
:param max_await_time_ms: The number of ms to wait for more data on future getMores for this command.
@ -1024,15 +1023,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
self,
command: Union[str, MutableMapping[str, Any]],
operation: str,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
) -> dict[str, Any]:
"""Same as command but used for retryable read commands."""
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
async def _cmd(
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
_server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> dict[str, Any]:
return await self._command(
@ -1046,8 +1045,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def _list_collections(
self,
conn: Connection,
session: Optional[ClientSession],
conn: AsyncConnection,
session: Optional[AsyncClientSession],
read_preference: _ServerMode,
**kwargs: Any,
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
@ -1075,7 +1074,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def _list_collections_helper(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
filter: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -1083,7 +1082,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
"""Get a cursor over the collections of this database.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param filter: A query document to filter the list of
collections returned from the listCollections command.
:param comment: A user-provided comment to attach to this
@ -1106,9 +1105,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
async def _cmd(
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
_server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
return await self._list_collections(
@ -1121,7 +1120,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def list_collections(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
filter: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -1129,7 +1128,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
"""Get a cursor over the collections of this database.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param filter: A query document to filter the list of
collections returned from the listCollections command.
:param comment: A user-provided comment to attach to this
@ -1149,7 +1148,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def _list_collection_names(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
filter: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -1174,7 +1173,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def list_collection_names(
self,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
filter: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
@ -1187,7 +1186,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
db.list_collection_names(filter=filter)
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param filter: A query document to filter the list of
collections returned from the listCollections command.
:param comment: A user-provided comment to attach to this
@ -1207,7 +1206,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
return await self._list_collection_names(session, filter, comment, **kwargs)
async def _drop_helper(
self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None
self, name: str, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None
) -> dict[str, Any]:
command = {"drop": name}
if comment is not None:
@ -1227,7 +1226,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def drop_collection(
self,
name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
encrypted_fields: Optional[Mapping[str, Any]] = None,
) -> dict[str, Any]:
@ -1236,7 +1235,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
:param name_or_collection: the name of a collection to drop or the
collection object itself
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param encrypted_fields: **(BETA)** Document that describes the encrypted fields for
@ -1306,7 +1305,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]],
scandata: bool = False,
full: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
background: Optional[bool] = None,
comment: Optional[Any] = None,
) -> dict[str, Any]:
@ -1326,7 +1325,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
of the structure of the collection and the individual
documents.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param background: A boolean flag that determines whether
the command runs in the background. Requires MongoDB 4.4+.
:param comment: A user-provided comment to attach to this
@ -1386,7 +1385,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async def dereference(
self,
dbref: DBRef,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> Optional[_DocumentType]:
@ -1401,7 +1400,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
:param dbref: the reference
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: any additional keyword arguments

View File

@ -59,16 +59,13 @@ from bson.errors import BSONError
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson
from pymongo import _csot
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.common import CONNECT_TIMEOUT
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.encryption_options import AutoEncryptionOpts, RangeOpts
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.operations import UpdateOne
from pymongo.asynchronous.pool import PoolOptions, _configured_socket, _raise_connection_failure
from pymongo.asynchronous.typings import _DocumentType, _DocumentTypeArg
from pymongo.asynchronous.uri_parser import parse_host
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
from pymongo.errors import (
ConfigurationError,
EncryptedCollectionError,
@ -78,9 +75,13 @@ from pymongo.errors import (
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
@ -381,7 +382,10 @@ class _Encrypter:
)
io_callbacks = _EncryptionIO( # type:ignore[misc]
metadata_client, key_vault_coll, mongocryptd_client, opts
metadata_client,
key_vault_coll, # type:ignore[arg-type]
mongocryptd_client,
opts,
)
self._auto_encrypter = AsyncAutoEncrypter(
io_callbacks,
@ -459,7 +463,14 @@ class Algorithm(str, enum.Enum):
RANGE = "Range"
"""Range.
.. versionadded:: 4.8
.. versionadded:: 4.9
"""
RANGEPREVIEW = "RangePreview"
"""**DEPRECATED** - RangePreview.
.. note:: Support for RangePreview is deprecated. Use :attr:`Algorithm.RANGE` instead.
.. versionadded:: 4.4
"""
@ -473,7 +484,18 @@ class QueryType(str, enum.Enum):
"""Used to encrypt a value for an equality query."""
RANGE = "range"
"""Used to encrypt a value for a range query."""
"""Used to encrypt a value for a range query.
.. versionadded:: 4.9
"""
RANGEPREVIEW = "RangePreview"
"""**DEPRECATED** - Used to encrypt a value for a rangePreview query.
.. note:: Support for RangePreview is deprecated. Use :attr:`QueryType.RANGE` instead.
.. versionadded:: 4.4
"""
def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
@ -570,7 +592,7 @@ class ClientEncryption(Generic[_DocumentType]):
raise ConfigurationError(
"client-side field level encryption requires the pymongocrypt "
"library: install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
"python -m pip install --upgrade 'pymongo[encryption]'"
)
if not isinstance(codec_options, CodecOptions):
@ -843,7 +865,7 @@ class ClientEncryption(Generic[_DocumentType]):
:return: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6.
.. versionchanged:: 4.8
.. versionchanged:: 4.9
Added the `range_opts` parameter.
.. versionchanged:: 4.7
@ -899,7 +921,7 @@ class ClientEncryption(Generic[_DocumentType]):
:return: The encrypted expression, a :class:`~bson.RawBSONDocument`.
.. versionchanged:: 4.8
.. versionchanged:: 4.9
Added the `range_opts` parameter.
.. versionchanged:: 4.7

View File

@ -1,272 +0,0 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for automatic client-side field level encryption."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
try:
import pymongocrypt # type:ignore[import] # noqa: F401
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.asynchronous.common import validate_is_mapping
from pymongo.asynchronous.uri_parser import _parse_kms_tls_options
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.typings import _DocumentTypeArg
_IS_SYNC = False
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
def __init__(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional[AsyncMongoClient[_DocumentTypeArg]] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: bool = False,
mongocryptd_uri: str = "mongodb://localhost:27020",
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = "mongocryptd",
mongocryptd_spawn_args: Optional[list[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None,
crypt_shared_lib_path: Optional[str] = None,
crypt_shared_lib_required: bool = False,
bypass_query_analysis: bool = False,
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
) -> None:
"""Options to configure automatic client-side field level encryption.
Automatic client-side field level encryption requires MongoDB >=4.2
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
supported for operations on a database or view and will result in
error.
Although automatic encryption requires MongoDB >=4.2 enterprise or a
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
users. To configure automatic *decryption* without automatic
*encryption* set ``bypass_auto_encryption=True``. Explicit
encryption and explicit decryption is also supported for all users
with the :class:`~pymongo.encryption.ClientEncryption` class.
See :ref:`automatic-client-side-encryption` for an example.
:param kms_providers: Map of KMS provider options. The `kms_providers`
map values differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages. An optional "sessionToken" may be
included to support temporary AWS credentials.
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
strings. Additionally, "identityPlatformEndpoint" may also be
specified as a string (defaults to 'login.microsoftonline.com').
These are the Azure Active Directory credentials used to
generate Azure Key Vault messages.
- `gcp`: Map with "email" as a string and "privateKey"
as `bytes` or a base64 encoded string.
Additionally, "endpoint" may also be specified as a string
(defaults to 'oauth2.googleapis.com'). These are the
credentials used to generate Google Cloud KMS messages.
- `kmip`: Map with "endpoint" as a host with required port.
For example: ``{"endpoint": "example.com:443"}``.
- `local`: Map with "key" as `bytes` (96 bytes in length) or
a base64 encoded string which decodes
to 96 bytes. "key" is the master key used to encrypt/decrypt
data keys. This key should be generated and stored as securely
as possible.
KMS providers may be specified with an optional name suffix
separated by a colon, for example "kmip:name" or "aws:name".
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
Named KMS providers enables more than one of each KMS provider type to be configured.
For example, to configure multiple local KMS providers::
kms_providers = {
"local": {"key": local_kek1}, # Unnamed KMS provider.
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
}
:param key_vault_namespace: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
:param key_vault_client: By default, the key vault collection
is assumed to reside in the same MongoDB cluster as the encrypted
AsyncMongoClient. Use this option to route data key queries to a
separate MongoDB cluster.
:param schema_map: Map of collection namespace ("db.coll") to
JSON Schema. By default, a collection's JSONSchema is periodically
polled with the listCollections command. But a JSONSchema may be
specified locally with the schemaMap option.
**Supplying a `schema_map` provides more security than relying on
JSON Schemas obtained from the server. It protects against a
malicious server advertising a false JSON Schema, which could trick
the client into sending unencrypted data that should be
encrypted.**
Schemas supplied in the schemaMap only apply to configuring
automatic encryption for client side encryption. Other validation
rules in the JSON schema will not be enforced by the driver and
will result in an error.
:param bypass_auto_encryption: If ``True``, automatic
encryption will be disabled but automatic decryption will still be
enabled. Defaults to ``False``.
:param mongocryptd_uri: The MongoDB URI used to connect
to the *local* mongocryptd process. Defaults to
``'mongodb://localhost:27020'``.
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
AsyncMongoClient will not attempt to spawn the mongocryptd process.
Defaults to ``False``.
:param mongocryptd_spawn_path: Used for spawning the
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
mongocryptd from the system path.
:param mongocryptd_spawn_args: A list of string arguments to
use when spawning the mongocryptd process. Defaults to
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
the ``idleShutdownTimeoutSecs`` option then
``'--idleShutdownTimeoutSecs=60'`` will be added.
:param kms_tls_options: A map of KMS provider names to TLS
options to use when creating secure connections to KMS providers.
Accepts the same TLS options as
:class:`pymongo.mongo_client.AsyncMongoClient`. For example, to
override the system default CA file::
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
Or to supply a client certificate::
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
unable to load the crypt_shared library.
:param bypass_query_analysis: If ``True``, disable automatic analysis
of outgoing commands. Set `bypass_query_analysis` to use explicit
encryption on indexed fields without the MongoDB Enterprise Advanced
licensed crypt_shared library.
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
that described the encrypted fields for Queryable Encryption. For example::
{
"db.encryptedCollection": {
"escCollection": "enxcol_.encryptedCollection.esc",
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
"fields": [
{
"path": "firstName",
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
"bsonType": "string",
"queries": {"queryType": "equality"}
},
{
"path": "ssn",
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
"bsonType": "string"
}
]
}
}
.. versionchanged:: 4.2
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
and `bypass_query_analysis` parameters.
.. versionchanged:: 4.0
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
self._bypass_query_analysis = bypass_query_analysis
self._crypt_shared_lib_path = crypt_shared_lib_path
self._crypt_shared_lib_required = crypt_shared_lib_required
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._schema_map = schema_map
self._bypass_auto_encryption = bypass_auto_encryption
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
if mongocryptd_spawn_args is None:
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError("mongocryptd_spawn_args must be a list")
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
self._bypass_query_analysis = bypass_query_analysis
class RangeOpts:
"""Options to configure encrypted queries using the range algorithm."""
def __init__(
self,
sparsity: int,
trim_factor: int,
min: Optional[Any] = None,
max: Optional[Any] = None,
precision: Optional[int] = None,
) -> None:
"""Options to configure encrypted queries using the range algorithm.
:param sparsity: An integer.
:param trim_factor: An integer.
:param min: A BSON scalar value corresponding to the type being queried.
:param max: A BSON scalar value corresponding to the type being queried.
:param precision: An integer, may only be set for double or decimal128 types.
.. versionadded:: 4.4
"""
self.min = min
self.max = max
self.sparsity = sparsity
self.trim_factor = trim_factor
self.precision = precision
@property
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),
("trimFactor", self.trim_factor),
("precision", self.precision),
("min", self.min),
("max", self.max),
]:
if v is not None:
doc[k] = v
return doc

View File

@ -1,225 +0,0 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example event logger classes.
.. versionadded:: 3.11
These loggers can be registered using :func:`register` or
:class:`~pymongo.mongo_client.MongoClient`.
``monitoring.register(CommandLogger())``
or
``MongoClient(event_listeners=[CommandLogger()])``
"""
from __future__ import annotations
import logging
from pymongo.asynchronous import monitoring
_IS_SYNC = False
class CommandLogger(monitoring.CommandListener):
"""A simple listener that logs command events.
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
:class:`~pymongo.monitoring.CommandSucceededEvent` and
:class:`~pymongo.monitoring.CommandFailedEvent` events and
logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.CommandStartedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} started on server "
f"{event.connection_id}"
)
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"succeeded in {event.duration_micros} "
"microseconds"
)
def failed(self, event: monitoring.CommandFailedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"failed in {event.duration_micros} "
"microseconds"
)
class ServerLogger(monitoring.ServerListener):
"""A simple listener that logs server discovery events.
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.ServerClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
previous_server_type = event.previous_description.server_type
new_server_type = event.new_description.server_type
if new_server_type != previous_server_type:
# server_type_name was added in PyMongo 3.4
logging.info(
f"Server {event.server_address} changed type from "
f"{event.previous_description.server_type_name} to "
f"{event.new_description.server_type_name}"
)
def closed(self, event: monitoring.ServerClosedEvent) -> None:
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
"""A simple listener that logs server heartbeat events.
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
logging.info(f"Heartbeat sent to server {event.connection_id}")
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
# The reply.document attribute was added in PyMongo 3.4.
logging.info(
f"Heartbeat to server {event.connection_id} "
"succeeded with reply "
f"{event.reply.document}"
)
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
logging.warning(
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
)
class TopologyLogger(monitoring.TopologyListener):
"""A simple listener that logs server topology events.
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.TopologyClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} opened")
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
logging.info(f"Topology description updated for topology id {event.topology_id}")
previous_topology_type = event.previous_description.topology_type
new_topology_type = event.new_description.topology_type
if new_topology_type != previous_topology_type:
# topology_type_name was added in PyMongo 3.4
logging.info(
f"Topology {event.topology_id} changed type from "
f"{event.previous_description.topology_type_name} to "
f"{event.new_description.topology_type_name}"
)
# The has_writable_server and has_readable_server methods
# were added in PyMongo 3.4.
if not event.new_description.has_writable_server():
logging.warning("No writable servers available.")
if not event.new_description.has_readable_server():
logging.warning("No readable servers available.")
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} closed")
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
"""A simple listener that logs server connection pool events.
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
:class:`~pymongo.monitoring.PoolClearedEvent`,
:class:`~pymongo.monitoring.PoolClosedEvent`,
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
logging.info(f"[pool {event.address}] pool created")
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
logging.info(f"[pool {event.address}] pool ready")
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
logging.info(f"[pool {event.address}] pool cleared")
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
logging.info(f"[pool {event.address}] pool closed")
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
)
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] "
f'connection closed, reason: "{event.reason}"'
)
def connection_check_out_started(
self, event: monitoring.ConnectionCheckOutStartedEvent
) -> None:
logging.info(f"[pool {event.address}] connection check out started")
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
)
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
)

View File

@ -1,26 +0,0 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The HelloCompat class, placed here to break circular import issues."""
from __future__ import annotations
_IS_SYNC = False
class HelloCompat:
CMD = "hello"
LEGACY_CMD = "ismaster"
PRIMARY = "isWritablePrimary"
LEGACY_PRIMARY = "ismaster"
LEGACY_ERROR = "not master"

View File

@ -1,4 +1,4 @@
# Copyright 2009-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,270 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bits and pieces used by the driver that don't really fit elsewhere."""
"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations
import builtins
import sys
import traceback
from collections import abc
from typing import (
TYPE_CHECKING,
Any,
Callable,
Container,
Iterable,
Mapping,
NoReturn,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
from pymongo import ASCENDING
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.errors import (
CursorNotFound,
DuplicateKeyError,
ExecutionTimeout,
NotPrimaryError,
OperationFailure,
WriteConcernError,
WriteError,
WTimeoutError,
_wtimeout_error,
)
from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE
if TYPE_CHECKING:
from pymongo.asynchronous.operations import _IndexList
from pymongo.asynchronous.typings import _DocumentOut
from pymongo.cursor_shared import _Hint
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
_IS_SYNC = False
def _gen_index_name(keys: _IndexList) -> str:
"""Generate an index name from the set of fields it is over."""
return "_".join(["{}_{}".format(*item) for item in keys])
def _index_list(
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
"""Helper to generate a list of (key, direction) pairs.
Takes such a list, or a single key, or a single key and direction.
"""
if direction is not None:
if not isinstance(key_or_list, str):
raise TypeError("Expected a string and a direction")
return [(key_or_list, direction)]
else:
if isinstance(key_or_list, str):
return [(key_or_list, ASCENDING)]
elif isinstance(key_or_list, abc.ItemsView):
return list(key_or_list) # type: ignore[arg-type]
elif isinstance(key_or_list, abc.Mapping):
return list(key_or_list.items())
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
values: list[tuple[str, int]] = []
for item in key_or_list:
if isinstance(item, str):
item = (item, ASCENDING) # noqa: PLW2901
values.append(item)
return values
def _index_document(index_list: _IndexList) -> dict[str, Any]:
"""Helper to generate an index specifying document.
Takes a list of (key, direction) pairs.
"""
if not isinstance(index_list, (list, tuple, abc.Mapping)):
raise TypeError(
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
)
if not len(index_list):
raise ValueError("key_or_list must not be empty")
index: dict[str, Any] = {}
if isinstance(index_list, abc.Mapping):
for key in index_list:
value = index_list[key]
_validate_index_key_pair(key, value)
index[key] = value
else:
for item in index_list:
if isinstance(item, str):
item = (item, ASCENDING) # noqa: PLW2901
key, value = item
_validate_index_key_pair(key, value)
index[key] = value
return index
def _validate_index_key_pair(key: Any, value: Any) -> None:
if not isinstance(key, str):
raise TypeError("first item in each key pair must be an instance of str")
if not isinstance(value, (str, int, abc.Mapping)):
raise TypeError(
"second item in each key pair must be 1, -1, "
"'2d', or another valid MongoDB index specifier."
)
def _check_command_response(
response: _DocumentOut,
max_wire_version: Optional[int],
allowable_errors: Optional[Container[Union[int, str]]] = None,
parse_write_concern_error: bool = False,
) -> None:
"""Check the response to a command for errors."""
if "ok" not in response:
# Server didn't recognize our message as a command.
raise OperationFailure(
response.get("$err"), # type: ignore[arg-type]
response.get("code"),
response,
max_wire_version,
)
if parse_write_concern_error and "writeConcernError" in response:
_error = response["writeConcernError"]
_labels = response.get("errorLabels")
if _labels:
_error.update({"errorLabels": _labels})
_raise_write_concern_error(_error)
if response["ok"]:
return
details = response
# Mongos returns the error details in a 'raw' object
# for some errors.
if "raw" in response:
for shard in response["raw"].values():
# Grab the first non-empty raw error from a shard.
if shard.get("errmsg") and not shard.get("ok"):
details = shard
break
errmsg = details["errmsg"]
code = details.get("code")
# For allowable errors, only check for error messages when the code is not
# included.
if allowable_errors:
if code is not None:
if code in allowable_errors:
return
elif errmsg in allowable_errors:
return
# Server is "not primary" or "recovering"
if code is not None:
if code in _NOT_PRIMARY_CODES:
raise NotPrimaryError(errmsg, response)
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
raise NotPrimaryError(errmsg, response)
# Other errors
# findAndModify with upsert can raise duplicate key error
if code in (11000, 11001, 12582):
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
elif code == 50:
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
elif code == 43:
raise CursorNotFound(errmsg, code, response, max_wire_version)
raise OperationFailure(errmsg, code, response, max_wire_version)
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
# If the last batch had multiple errors only report
# the last error to emulate continue_on_error.
error = write_errors[-1]
if error.get("code") == 11000:
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
raise WriteError(error.get("errmsg"), error.get("code"), error)
def _raise_write_concern_error(error: Any) -> NoReturn:
if _wtimeout_error(error):
# Make sure we raise WTimeoutError
raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
raise WriteConcernError(error.get("errmsg"), error.get("code"), error)
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
"""Return the writeConcernError or None."""
wce = result.get("writeConcernError")
if wce:
# The server reports errorLabels at the top level but it's more
# convenient to attach it to the writeConcernError doc itself.
error_labels = result.get("errorLabels")
if error_labels:
# Copy to avoid changing the original document.
wce = wce.copy()
wce["errorLabels"] = error_labels
return wce
def _check_write_command_response(result: Mapping[str, Any]) -> None:
"""Backward compatibility helper for write command error handling."""
# Prefer write errors over write concern errors
write_errors = result.get("writeErrors")
if write_errors:
_raise_last_write_error(write_errors)
wce = _get_wce_doc(result)
if wce:
_raise_write_concern_error(wce)
def _fields_list_to_dict(
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
"""Takes a sequence of field names and returns a matching dictionary.
["a", "b"] becomes {"a": 1, "b": 1}
and
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
"""
if isinstance(fields, abc.Mapping):
return fields
if isinstance(fields, (abc.Sequence, abc.Set)):
if not all(isinstance(field, str) for field in fields):
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
return dict.fromkeys(fields, 1)
raise TypeError(f"{option_name} must be a mapping or list of key names")
def _handle_exception() -> None:
"""Print exceptions raised by subscribers to stderr."""
# Heavily influenced by logging.Handler.handleError.
# See note here:
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
if sys.stderr:
einfo = sys.exc_info()
try:
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
except OSError:
pass
finally:
del einfo
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])
@ -283,8 +38,8 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.message import _BulkWriteContext
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.message import _BulkWriteContext
try:
return await func(*args, **kwargs)
@ -292,16 +47,16 @@ def _handle_reauth(func: F) -> F:
if no_reauth:
raise
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
# Look for an argument that either is a Connection
# Look for an argument that either is a AsyncConnection
# or has a connection attribute, so we can trigger
# a reauth.
conn = None
for arg in args:
if isinstance(arg, Connection):
if isinstance(arg, AsyncConnection):
conn = arg
break
if isinstance(arg, _BulkWriteContext):
conn = arg.conn
conn = arg.conn # type: ignore[assignment]
break
if conn:
await conn.authenticate(reauthenticate=True)

View File

@ -1,171 +0,0 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
import logging
import os
import warnings
from typing import Any
from bson import UuidRepresentation, json_util
from bson.json_util import JSONOptions, _truncate_documents
from pymongo.asynchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
_IS_SYNC = False
class _CommandStatusMessage(str, enum.Enum):
STARTED = "Command started"
SUCCEEDED = "Command succeeded"
FAILED = "Command failed"
class _ServerSelectionStatusMessage(str, enum.Enum):
STARTED = "Server selection started"
SUCCEEDED = "Server selection succeeded"
FAILED = "Server selection failed"
WAITING = "Waiting for suitable server to become available"
class _ConnectionStatusMessage(str, enum.Enum):
POOL_CREATED = "Connection pool created"
POOL_READY = "Connection pool ready"
POOL_CLOSED = "Connection pool closed"
POOL_CLEARED = "Connection pool cleared"
CONN_CREATED = "Connection created"
CONN_READY = "Connection ready"
CONN_CLOSED = "Connection closed"
CHECKOUT_STARTED = "Connection checkout started"
CHECKOUT_SUCCEEDED = "Connection checked out"
CHECKOUT_FAILED = "Connection checkout failed"
CHECKEDIN = "Connection checked in"
_DEFAULT_DOCUMENT_LENGTH = 1000
_SENSITIVE_COMMANDS = [
"authenticate",
"saslStart",
"saslContinue",
"getnonce",
"createUser",
"updateUser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
]
_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"]
_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"]
_DOCUMENT_NAMES = ["command", "reply", "failure"]
_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD)
_COMMAND_LOGGER = logging.getLogger("pymongo.command")
_CONNECTION_LOGGER = logging.getLogger("pymongo.connection")
_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection")
_CLIENT_LOGGER = logging.getLogger("pymongo.client")
_VERBOSE_CONNECTION_ERROR_REASONS = {
ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed",
ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed",
ConnectionClosedReason.STALE: "Connection pool was stale",
ConnectionClosedReason.ERROR: "An error occurred while using the connection",
ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection",
ConnectionClosedReason.IDLE: "Connection was idle too long",
ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout",
}
def _debug_log(logger: logging.Logger, **fields: Any) -> None:
logger.debug(LogMessage(**fields))
def _verbose_connection_error_reason(reason: str) -> str:
return _VERBOSE_CONNECTION_ERROR_REASONS.get(reason, reason)
def _info_log(logger: logging.Logger, **fields: Any) -> None:
logger.info(LogMessage(**fields))
def _log_or_warn(logger: logging.Logger, message: str) -> None:
if logger.isEnabledFor(logging.INFO):
logger.info(message)
else:
# stacklevel=4 ensures that the warning is for the user's code.
warnings.warn(message, UserWarning, stacklevel=4)
class LogMessage:
__slots__ = ("_kwargs", "_redacted")
def __init__(self, **kwargs: Any):
self._kwargs = kwargs
self._redacted = False
def __str__(self) -> str:
self._redact()
return "%s" % (
json_util.dumps(
self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__()
)
)
def _is_sensitive(self, doc_name: str) -> bool:
is_speculative_authenticate = (
self._kwargs.pop("speculative_authenticate", False)
or "speculativeAuthenticate" in self._kwargs[doc_name]
)
is_sensitive_command = (
"commandName" in self._kwargs and self._kwargs["commandName"] in _SENSITIVE_COMMANDS
)
is_sensitive_hello = (
self._kwargs["commandName"] in _HELLO_COMMANDS and is_speculative_authenticate
)
return is_sensitive_command or is_sensitive_hello
def _redact(self) -> None:
if self._redacted:
return
self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None}
if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"):
self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000
if "serviceId" in self._kwargs:
self._kwargs["serviceId"] = str(self._kwargs["serviceId"])
document_length = int(os.getenv("MONGOB_LOG_MAX_DOCUMENT_LENGTH", _DEFAULT_DOCUMENT_LENGTH))
if document_length < 0:
document_length = _DEFAULT_DOCUMENT_LENGTH
is_server_side_error = self._kwargs.pop("isServerSideError", False)
for doc_name in _DOCUMENT_NAMES:
doc = self._kwargs.get(doc_name)
if doc:
if doc_name == "failure" and is_server_side_error:
doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS}
if doc_name != "failure" and self._is_sensitive(doc_name):
doc = json_util.dumps({})
else:
truncated_doc = _truncate_documents(doc, document_length)[0]
doc = json_util.dumps(
truncated_doc,
json_options=_JSON_OPTIONS,
default=lambda o: o.__repr__(),
)
if len(doc) > document_length:
doc = (
doc.encode()[:document_length].decode("unicode-escape", "ignore")
) + "..."
self._kwargs[doc_name] = doc
self._redacted = True

View File

@ -1,125 +0,0 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Criteria to select ServerDescriptions based on maxStalenessSeconds.
The Max Staleness Spec says: When there is a known primary P,
a secondary S's staleness is estimated with this formula:
(S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate)
+ heartbeatFrequencyMS
When there is no known primary, a secondary S's staleness is estimated with:
SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS
where "SMax" is the secondary with the greatest lastWriteDate.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from pymongo.errors import ConfigurationError
from pymongo.server_type import SERVER_TYPE
if TYPE_CHECKING:
from pymongo.asynchronous.server_selectors import Selection
_IS_SYNC = False
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
# 10 seconds to refresh secondaries' lastWriteDate values.
IDLE_WRITE_PERIOD = 10
SMALLEST_MAX_STALENESS = 90
def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None:
# We checked for max staleness -1 before this, it must be positive here.
if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD:
raise ConfigurationError(
"maxStalenessSeconds must be at least heartbeatFrequencyMS +"
" %d seconds. maxStalenessSeconds is set to %d,"
" heartbeatFrequencyMS is set to %d."
% (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000)
)
if max_staleness < SMALLEST_MAX_STALENESS:
raise ConfigurationError(
"maxStalenessSeconds must be at least %d. "
"maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness)
)
def _with_primary(max_staleness: int, selection: Selection) -> Selection:
"""Apply max_staleness, in seconds, to a Selection with a known primary."""
primary = selection.primary
assert primary
sds = []
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
assert s.last_write_date and primary.last_write_date # noqa: PT018
staleness = (
(s.last_update_time - s.last_write_date)
- (primary.last_update_time - primary.last_write_date)
+ selection.heartbeat_frequency
)
if staleness <= max_staleness:
sds.append(s)
else:
sds.append(s)
return selection.with_server_descriptions(sds)
def _no_primary(max_staleness: int, selection: Selection) -> Selection:
"""Apply max_staleness, in seconds, to a Selection with no known primary."""
# Secondary that's replicated the most recent writes.
smax = selection.secondary_with_max_last_write_date()
if not smax:
# No secondaries and no primary, short-circuit out of here.
return selection.with_server_descriptions([])
sds = []
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
assert smax.last_write_date and s.last_write_date # noqa: PT018
staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency
if staleness <= max_staleness:
sds.append(s)
else:
sds.append(s)
return selection.with_server_descriptions(sds)
def select(max_staleness: int, selection: Selection) -> Selection:
"""Apply max_staleness, in seconds, to a Selection."""
if max_staleness == -1:
return selection
# Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or
# ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness <
# heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90.
_validate_max_staleness(max_staleness, selection.heartbeat_frequency)
if selection.primary:
return _with_primary(max_staleness, selection)
else:
return _no_primary(max_staleness, selection)

File diff suppressed because it is too large Load Diff

View File

@ -58,42 +58,14 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, helpers_constants
from pymongo.asynchronous import (
client_session,
common,
database,
helpers,
message,
periodic_executor,
uri_parser,
)
from pymongo import _csot, common, helpers_shared, uri_parser
from pymongo.asynchronous import client_session, database, periodic_executor
from pymongo.asynchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.asynchronous.client_options import ClientOptions
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.asynchronous.monitoring import ConnectionClosedReason
from pymongo.asynchronous.operations import _Op
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.asynchronous.server_selectors import writable_server_selector
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, TopologyDescription
from pymongo.asynchronous.typings import (
ClusterTime,
_Address,
_CollationIn,
_DocumentType,
_DocumentTypeArg,
_Pipeline,
)
from pymongo.asynchronous.uri_parser import (
_check_options,
_handle_option_deprecations,
_handle_security_options,
_normalize_options,
)
from pymongo.client_options import ClientOptions
from pymongo.errors import (
AutoReconnect,
BulkWriteError,
@ -108,29 +80,52 @@ from pymongo.errors import (
WriteConcernError,
)
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
from pymongo.typings import (
ClusterTime,
_Address,
_CollationIn,
_DocumentType,
_DocumentTypeArg,
_Pipeline,
)
from pymongo.uri_parser import (
_check_options,
_handle_option_deprecations,
_handle_security_options,
_normalize_options,
)
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
if TYPE_CHECKING:
from types import TracebackType
from bson.objectid import ObjectId
from pymongo.asynchronous.bulk import _Bulk
from pymongo.asynchronous.client_session import ClientSession, _ServerSession
from pymongo.asynchronous.bulk import _AsyncBulk
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.message import _CursorAddress, _GetMore, _Query
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.response import Response
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.asynchronous.server_selectors import Selection
from pymongo.read_concern import ReadConcern
from pymongo.response import Response
from pymongo.server_selectors import Selection
T = TypeVar("T")
_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], Coroutine[Any, Any, T]]
_WriteCall = Callable[
[Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T]
]
_ReadCall = Callable[
[Optional["ClientSession"], "Server", "Connection", _ServerMode], Coroutine[Any, Any, T]
[Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode],
Coroutine[Any, Any, T],
]
_IS_SYNC = False
@ -892,7 +887,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._kill_cursors_executor = executor
self._opened = False
def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]:
def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[bool]:
return self._options.load_balanced and not (session and session.in_transaction)
def _after_fork(self) -> None:
@ -915,7 +910,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
batch_size: Optional[int] = None,
collation: Optional[_CollationIn] = None,
start_at_operation_time: Optional[Timestamp] = None,
session: Optional[client_session.ClientSession] = None,
session: Optional[client_session.AsyncClientSession] = None,
start_after: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
@ -989,7 +984,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
the specified :class:`~bson.timestamp.Timestamp`. Requires
MongoDB >= 4.0.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param start_after: The same as `resume_after` except that
`start_after` can resume notifications after an invalidate event.
This option and `resume_after` are mutually exclusive.
@ -1163,30 +1158,30 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
"""Request that a cursor and/or connection be cleaned up soon."""
self._kill_cursors_queue.append((address, cursor_id, conn_mgr))
def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, server_session, opts, implicit)
return client_session.AsyncClientSession(self, server_session, opts, implicit)
def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
) -> client_session.ClientSession:
) -> client_session.AsyncClientSession:
"""Start a logical session.
This method takes the same parameters as
:class:`~pymongo.client_session.SessionOptions`. See the
:mod:`~pymongo.client_session` module for details and examples.
A :class:`~pymongo.client_session.ClientSession` may only be used with
the MongoClient that started it. :class:`ClientSession` instances are
A :class:`~pymongo.client_session.AsyncClientSession` may only be used with
the MongoClient that started it. :class:`AsyncClientSession` instances are
**not thread-safe or fork-safe**. They can only be used by one thread
or process at a time. A single :class:`ClientSession` cannot be used
or process at a time. A single :class:`AsyncClientSession` cannot be used
to run multiple operations concurrently.
:return: An instance of :class:`~pymongo.client_session.ClientSession`.
:return: An instance of :class:`~pymongo.client_session.AsyncClientSession`.
.. versionadded:: 3.6
"""
@ -1197,7 +1192,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
snapshot=snapshot,
)
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
def _ensure_session(
self, session: Optional[AsyncClientSession] = None
) -> Optional[AsyncClientSession]:
"""If provided session is None, lend a temporary session."""
if session:
return session
@ -1211,7 +1208,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
return None
def _send_cluster_time(
self, command: MutableMapping[str, Any], session: Optional[ClientSession]
self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession]
) -> None:
topology_time = self._topology.max_cluster_time()
session_time = session.cluster_time if session else None
@ -1474,7 +1471,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def _end_sessions(self, session_ids: list[_ServerSession]) -> None:
"""Send endSessions command(s) with the given session ids."""
try:
# Use Connection.command directly to avoid implicitly creating
# Use AsyncConnection.command directly to avoid implicitly creating
# another session.
async with await self._conn_for_reads(
ReadPreference.PRIMARY_PREFERRED, None, operation=_Op.END_SESSIONS
@ -1535,8 +1532,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.asynccontextmanager
async def _checkout(
self, server: Server, session: Optional[ClientSession]
) -> AsyncGenerator[Connection, None]:
self, server: Server, session: Optional[AsyncClientSession]
) -> AsyncGenerator[AsyncConnection, None]:
in_txn = session and session.in_transaction
async with _MongoClientErrorHandler(self, server, session) as err_handler:
# Reuse the pinned connection, if it exists.
@ -1570,7 +1567,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def _select_server(
self,
server_selector: Callable[[Selection], Selection],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
@ -1581,7 +1578,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:Parameters:
- `server_selector`: The server selector to use if the session is
not pinned and no address is given.
- `session`: The ClientSession for the next operation, or None. May
- `session`: The AsyncClientSession for the next operation, or None. May
be pinned to a mongos server address.
- `address` (optional): Address when sending a message
to a specific server, used for getMore.
@ -1615,15 +1612,15 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
raise
async def _conn_for_writes(
self, session: Optional[ClientSession], operation: str
) -> AsyncContextManager[Connection]:
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
server = await self._select_server(writable_server_selector, session, operation)
return self._checkout(server, session)
@contextlib.asynccontextmanager
async def _conn_from_server(
self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession]
) -> AsyncGenerator[tuple[Connection, _ServerMode], None]:
self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession]
) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]:
assert read_preference is not None, "read_preference must not be None"
# Get a connection for a server matching the read preference, and yield
# conn with the effective read preference. The Server Selection
@ -1648,9 +1645,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def _conn_for_reads(
self,
read_preference: _ServerMode,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
) -> AsyncContextManager[tuple[Connection, _ServerMode]]:
) -> AsyncContextManager[tuple[AsyncConnection, _ServerMode]]:
assert read_preference is not None, "read_preference must not be None"
server = await self._select_server(read_preference, session, operation)
return self._conn_from_server(read_preference, server, session)
@ -1672,13 +1669,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if operation.conn_mgr:
server = await self._select_server(
operation.read_preference,
operation.session,
operation.session, # type: ignore[arg-type]
operation.name,
address=address,
)
async with operation.conn_mgr._alock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return await server.run_operation(
operation.conn_mgr.conn,
@ -1690,9 +1687,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
)
async def _cmd(
_session: Optional[ClientSession],
_session: Optional[AsyncClientSession],
server: Server,
conn: Connection,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> Response:
operation.reset() # Reset op in case of retry.
@ -1708,9 +1705,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
return await self._retryable_read(
_cmd,
operation.read_preference,
operation.session,
operation.session, # type: ignore[arg-type]
address=address,
retryable=isinstance(operation, message._Query),
retryable=isinstance(operation, _Query),
operation=operation.name,
)
@ -1718,8 +1715,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self,
retryable: bool,
func: _WriteCall[T],
session: Optional[ClientSession],
bulk: Optional[_Bulk],
session: Optional[AsyncClientSession],
bulk: Optional[_AsyncBulk],
operation: str,
operation_id: Optional[int] = None,
) -> T:
@ -1748,8 +1745,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def _retry_internal(
self,
func: _WriteCall[T] | _ReadCall[T],
session: Optional[ClientSession],
bulk: Optional[_Bulk],
session: Optional[AsyncClientSession],
bulk: Optional[_AsyncBulk],
operation: str,
is_read: bool = False,
address: Optional[_Address] = None,
@ -1787,7 +1784,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self,
func: _ReadCall[T],
read_pref: _ServerMode,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
address: Optional[_Address] = None,
retryable: bool = True,
@ -1830,9 +1827,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self,
retryable: bool,
func: _WriteCall[T],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
operation: str,
bulk: Optional[_Bulk] = None,
bulk: Optional[_AsyncBulk] = None,
operation_id: Optional[int] = None,
) -> T:
"""Execute an operation with consecutive retries if possible
@ -1856,7 +1853,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
cursor_id: int,
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from __del__ without locking.
@ -1880,7 +1877,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
cursor_id: int,
address: Optional[_CursorAddress],
conn_mgr: _ConnectionManager,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
explicit_session: bool,
) -> None:
"""Cleanup a cursor from cursor.close() using a lock.
@ -1913,7 +1910,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self,
cursor_id: int,
address: Optional[_CursorAddress],
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
conn_mgr: Optional[_ConnectionManager] = None,
) -> None:
"""Send a kill cursors message with the given id.
@ -1941,7 +1938,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
cursor_ids: Sequence[int],
address: Optional[_CursorAddress],
topology: Topology,
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
) -> None:
"""Send a kill cursors message with the given ids."""
if address:
@ -1960,8 +1957,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self,
cursor_ids: Sequence[int],
address: _CursorAddress,
session: Optional[ClientSession],
conn: Connection,
session: Optional[AsyncClientSession],
conn: AsyncConnection,
) -> None:
namespace = address.namespace
db, coll = namespace.split(".", 1)
@ -1994,7 +1991,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# can be caught in _process_periodic_tasks
raise
else:
helpers._handle_exception()
helpers_shared._handle_exception()
# Don't re-open topology if it's closed and there's no pending cursors.
if address_to_cursor_ids:
@ -2006,7 +2003,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
else:
helpers._handle_exception()
helpers_shared._handle_exception()
# This method is run periodically by a background thread.
async def _process_periodic_tasks(self) -> None:
@ -2020,7 +2017,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if isinstance(exc, InvalidOperation) and self._topology._closed:
return
else:
helpers._handle_exception()
helpers_shared._handle_exception()
def _return_server_session(
self, server_session: Union[_ServerSession, _EmptyServerSession]
@ -2032,12 +2029,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.asynccontextmanager
async def _tmp_session(
self, session: Optional[client_session.ClientSession], close: bool = True
) -> AsyncGenerator[Optional[client_session.ClientSession], None, None]:
self, session: Optional[client_session.AsyncClientSession], close: bool = True
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.ClientSession):
raise ValueError("'session' argument must be a ClientSession or None.")
if not isinstance(session, client_session.AsyncClientSession):
raise ValueError("'session' argument must be an AsyncClientSession or None.")
# Don't call end_session.
yield session
return
@ -2061,19 +2058,19 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
yield None
async def _process_response(
self, reply: Mapping[str, Any], session: Optional[ClientSession]
self, reply: Mapping[str, Any], session: Optional[AsyncClientSession]
) -> None:
await self._topology.receive_cluster_time(reply.get("$clusterTime"))
if session is not None:
session._process_response(reply)
async def server_info(
self, session: Optional[client_session.ClientSession] = None
self, session: Optional[client_session.AsyncClientSession] = None
) -> dict[str, Any]:
"""Get information about the MongoDB server we're connected to.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
.. versionchanged:: 3.6
Added ``session`` parameter.
@ -2087,7 +2084,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def _list_databases(
self,
session: Optional[client_session.ClientSession] = None,
session: Optional[client_session.AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> AsyncCommandCursor[dict[str, Any]]:
@ -2109,14 +2106,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def list_databases(
self,
session: Optional[client_session.ClientSession] = None,
session: Optional[client_session.AsyncClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> AsyncCommandCursor[dict[str, Any]]:
"""Get a cursor over the databases of the connected server.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
:param kwargs: Optional parameters of the
@ -2134,13 +2131,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def list_database_names(
self,
session: Optional[client_session.ClientSession] = None,
session: Optional[client_session.AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> list[str]:
"""Get a list of the names of all databases on the connected server.
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -2156,7 +2153,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def drop_database(
self,
name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]],
session: Optional[client_session.ClientSession] = None,
session: Optional[client_session.AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> None:
"""Drop a database.
@ -2168,7 +2165,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:class:`~pymongo.database.Database` instance representing the
database to drop
:param session: a
:class:`~pymongo.client_session.ClientSession`.
:class:`~pymongo.client_session.AsyncClientSession`.
:param comment: A user-provided comment to attach to this
command.
@ -2236,10 +2233,10 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong
# Do not consult writeConcernError for pre-4.4 mongos.
if isinstance(exc, WriteConcernError) and is_mongos:
pass
elif code in helpers_constants._RETRYABLE_ERROR_CODES:
elif code in helpers_shared._RETRYABLE_ERROR_CODES:
exc._add_error_label("RetryableWriteError")
# Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is
# AsyncConnection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is
# handled above.
if isinstance(exc, ConnectionFailure) and not isinstance(
exc, (NotPrimaryError, WaitQueueTimeoutError)
@ -2261,7 +2258,9 @@ class _MongoClientErrorHandler:
"handled",
)
def __init__(self, client: AsyncMongoClient, server: Server, session: Optional[ClientSession]):
def __init__(
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
):
self.client = client
self.server_address = server.description.address
self.session = session
@ -2275,7 +2274,7 @@ class _MongoClientErrorHandler:
self.service_id: Optional[ObjectId] = None
self.handled = False
def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None:
def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None:
"""Provide socket information to the error handler."""
self.max_wire_version = conn.max_wire_version
self.sock_generation = conn.generation
@ -2327,10 +2326,10 @@ class _ClientConnectionRetryable(Generic[T]):
self,
mongo_client: AsyncMongoClient,
func: _WriteCall[T] | _ReadCall[T],
bulk: Optional[_Bulk],
bulk: Optional[_AsyncBulk],
operation: str,
is_read: bool = False,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
read_pref: Optional[_ServerMode] = None,
address: Optional[_Address] = None,
retryable: bool = False,
@ -2392,7 +2391,7 @@ class _ClientConnectionRetryable(Generic[T]):
exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_constants._RETRYABLE_ERROR_CODES
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
):
raise
self._retrying = True

View File

@ -21,19 +21,20 @@ import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common
from pymongo._csot import MovingMinimum
from pymongo.asynchronous import common, periodic_executor
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.periodic_executor import _shutdown_executors
from pymongo.asynchronous.pool import _is_faas
from pymongo.asynchronous.read_preferences import MovingAverage
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.srv_resolver import _SrvResolver
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.asynchronous.pool import Connection, Pool, _CancellationContext
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
@ -294,7 +295,7 @@ class Monitor(MonitorBase):
)
return sd
async def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]:
"""Return (Hello, round_trip_time).
Can raise ConnectionFailure or OperationFailure.

File diff suppressed because it is too large Load Diff

View File

@ -33,20 +33,18 @@ from typing import (
)
from bson import _decode_all_selective
from pymongo import _csot
from pymongo.asynchronous import helpers as _async_helpers
from pymongo.asynchronous import message as _async_message
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
from pymongo.asynchronous.compression_support import _NO_COMPRESSION, decompress
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.asynchronous.monitoring import _is_speculative_authenticate
from pymongo import _csot, helpers_shared, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
@ -58,27 +56,27 @@ from pymongo.socket_checker import _errno_from_exception
if TYPE_CHECKING:
from bson import CodecOptions
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.monitoring import _EventListeners
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.monitoring import _EventListeners
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
async def command(
conn: Connection,
conn: AsyncConnection,
dbname: str,
spec: MutableMapping[str, Any],
is_mongos: bool,
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
client: Optional[AsyncMongoClient],
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
@ -97,13 +95,13 @@ async def command(
) -> _DocumentType:
"""Execute a command over the socket, or raise socket.error.
:param conn: a Connection instance
:param conn: a AsyncConnection instance
:param dbname: name of the database on which to run the command
:param spec: a command document as an ordered dict type, eg SON.
:param is_mongos: are we connected to a mongos?
:param read_preference: a read preference
:param codec_options: a CodecOptions instance
:param session: optional ClientSession instance.
:param session: optional AsyncClientSession instance.
:param client: optional AsyncMongoClient instance for updating $clusterTime.
:param check: raise OperationFailure if there are errors
:param allowable_errors: errors to ignore if `check` is True
@ -130,7 +128,7 @@ async def command(
orig = spec
if is_mongos and not use_op_msg:
assert read_preference is not None
spec = _async_message._maybe_add_read_preference(spec, read_preference)
spec = message._maybe_add_read_preference(spec, read_preference)
if read_concern and not (session and session.in_transaction):
if read_concern.level:
spec["readConcern"] = read_concern.document
@ -158,22 +156,20 @@ async def command(
if use_op_msg:
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
request_id, msg, size, max_doc_size = _async_message._op_msg(
request_id, msg, size, max_doc_size = message._op_msg(
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
)
# If this is an unacknowledged write then make sure the encoded doc(s)
# are small enough, otherwise rely on the server to return an error.
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
_async_message._raise_document_too_large(name, size, max_bson_size)
message._raise_document_too_large(name, size, max_bson_size)
else:
request_id, msg, size = _async_message._query(
request_id, msg, size = message._query(
0, ns, 0, -1, spec, None, codec_options, compression_ctx
)
if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD:
_async_message._raise_document_too_large(
name, size, max_bson_size + _async_message._COMMAND_OVERHEAD
)
if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
@ -220,7 +216,7 @@ async def command(
if client:
await client._process_response(response_doc, session)
if check:
_async_helpers._check_command_response(
helpers_shared._check_command_response(
response_doc,
conn.max_wire_version,
allowable_errors,
@ -231,7 +227,7 @@ async def command(
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _async_message._convert_exception(exc)
failure = message._convert_exception(exc)
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
@ -310,7 +306,7 @@ async def command(
async def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
@ -355,7 +351,7 @@ async def receive_message(
return unpack_reply(data)
async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
timed_out = False
@ -390,7 +386,7 @@ async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
async def _receive_data_on_socket(
conn: Connection, length: int, deadline: Optional[float]
conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)

View File

@ -1,625 +0,0 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operation class definitions."""
from __future__ import annotations
import enum
from typing import (
TYPE_CHECKING,
Any,
Generic,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from bson.raw_bson import RawBSONDocument
from pymongo.asynchronous import helpers
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.common import validate_is_mapping, validate_list
from pymongo.asynchronous.helpers import _gen_index_name, _index_document, _index_list
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.write_concern import validate_boolean
if TYPE_CHECKING:
from pymongo.asynchronous.bulk import _Bulk
_IS_SYNC = False
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
_IndexList = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_IndexKeyHint = Union[str, _IndexList]
class _Op(str, enum.Enum):
ABORT = "abortTransaction"
AGGREGATE = "aggregate"
COMMIT = "commitTransaction"
COUNT = "count"
CREATE = "create"
CREATE_INDEXES = "createIndexes"
CREATE_SEARCH_INDEXES = "createSearchIndexes"
DELETE = "delete"
DISTINCT = "distinct"
DROP = "drop"
DROP_DATABASE = "dropDatabase"
DROP_INDEXES = "dropIndexes"
DROP_SEARCH_INDEXES = "dropSearchIndexes"
END_SESSIONS = "endSessions"
FIND_AND_MODIFY = "findAndModify"
FIND = "find"
INSERT = "insert"
LIST_COLLECTIONS = "listCollections"
LIST_INDEXES = "listIndexes"
LIST_SEARCH_INDEX = "listSearchIndexes"
LIST_DATABASES = "listDatabases"
UPDATE = "update"
UPDATE_INDEX = "updateIndex"
UPDATE_SEARCH_INDEX = "updateSearchIndex"
RENAME = "rename"
GETMORE = "getMore"
KILL_CURSORS = "killCursors"
TEST = "testOperation"
class InsertOne(Generic[_DocumentType]):
"""Represents an insert_one operation."""
__slots__ = ("_doc",)
def __init__(self, document: _DocumentType) -> None:
"""Create an InsertOne instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param document: The document to insert. If the document is missing an
_id field one will be added.
"""
self._doc = document
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_insert(self._doc) # type: ignore[arg-type]
def __repr__(self) -> str:
return f"InsertOne({self._doc!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return other._doc == self._doc
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteOne:
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteOne instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the document to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
1,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteMany:
"""Represents a delete_many operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteMany instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the documents to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
0,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class ReplaceOne(Generic[_DocumentType]):
"""Represents a replace_one operation."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
replacement: Union[_DocumentType, RawBSONDocument],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a ReplaceOne instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the document to replace.
:param replacement: The new document.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the ``collation`` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = replacement
self._upsert = upsert
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_replace(
self._filter,
self._doc,
self._upsert,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
other._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._hint,
)
class _UpdateOp:
"""Private base class for update operations."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
doc: Union[Mapping[str, Any], _Pipeline],
upsert: bool,
collation: Optional[_CollationIn],
array_filters: Optional[list[Mapping[str, Any]]],
hint: Optional[_IndexKeyHint],
):
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if array_filters is not None:
validate_list("array_filters", array_filters)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = doc
self._upsert = upsert
self._collation = collation
self._array_filters = array_filters
def __eq__(self, other: object) -> bool:
if isinstance(other, type(self)):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._array_filters,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
return NotImplemented
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
class UpdateOne(_UpdateOp):
"""Represents an update_one operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Represents an update_one operation.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the document to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
False,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class UpdateMany(_UpdateOp):
"""Represents an update_many operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create an UpdateMany instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the documents to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
True,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class IndexModel:
"""Represents an index to create."""
__slots__ = ("__document",)
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
"""Create an Index instance.
For use with :meth:`~pymongo.collection.AsyncCollection.create_indexes`.
Takes either a single key or a list containing (key, direction) pairs
or keys. If no direction is given, :data:`~pymongo.ASCENDING` will
be assumed.
The key(s) must be an instance of :class:`str`, and the direction(s) must
be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
:data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
Valid options include, but are not limited to:
- `name`: custom name to use for this index - if none is
given, a name will be generated.
- `unique`: if ``True``, creates a uniqueness constraint on the index.
- `background`: if ``True``, this index should be created in the
background.
- `sparse`: if ``True``, omit from the index any documents that lack
the indexed field.
- `bucketSize`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity
to a given longitude and latitude.
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
index.
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
index.
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
collection. MongoDB will automatically delete documents from
this collection after <int> seconds. The indexed field must
be a UTC datetime or the data will not expire.
- `partialFilterExpression`: A document that specifies a filter for
a partial index.
- `collation`: An instance of :class:`~pymongo.collation.Collation`
that specifies the collation to use.
- `wildcardProjection`: Allows users to include or exclude specific
field paths from a `wildcard index`_ using the { "$**" : 1} key
pattern. Requires MongoDB >= 4.2.
- `hidden`: if ``True``, this index will be hidden from the query
planner and will not be evaluated as part of query plan
selection. Requires MongoDB >= 4.4.
See the MongoDB documentation for a full list of supported options by
server version.
:param keys: a single key or a list containing (key, direction) pairs
or keys specifying the index to create.
:param kwargs: any additional index creation
options (see the above list) should be passed as keyword
arguments.
.. versionchanged:: 3.11
Added the ``hidden`` option.
.. versionchanged:: 3.2
Added the ``partialFilterExpression`` option to support partial
indexes.
.. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
"""
keys = _index_list(keys)
if kwargs.get("name") is None:
kwargs["name"] = _gen_index_name(keys)
kwargs["key"] = _index_document(keys)
collation = validate_collation_or_none(kwargs.pop("collation", None))
self.__document = kwargs
if collation is not None:
self.__document["collation"] = collation
@property
def document(self) -> dict[str, Any]:
"""An index document suitable for passing to the createIndexes
command.
"""
return self.__document
class SearchIndexModel:
"""Represents a search index to create."""
__slots__ = ("__document",)
def __init__(
self,
definition: Mapping[str, Any],
name: Optional[str] = None,
type: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a Search Index instance.
For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`.
:param definition: The definition for this index.
:param name: The name for this index, if present.
:param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
:param kwargs: Keyword arguments supplying any additional options.
.. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
.. versionadded:: 4.5
.. versionchanged:: 4.7
Added the type and kwargs arguments.
"""
self.__document: dict[str, Any] = {}
if name is not None:
self.__document["name"] = name
self.__document["definition"] = definition
if type is not None:
self.__document["type"] = type
self.__document.update(kwargs)
@property
def document(self) -> Mapping[str, Any]:
"""The document for this index."""
return self.__document

View File

@ -16,17 +16,14 @@ from __future__ import annotations
import collections
import contextlib
import copy
import logging
import os
import platform
import socket
import ssl
import sys
import threading
import time
import weakref
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
@ -39,39 +36,18 @@ from typing import (
Union,
)
import bson
from bson import DEFAULT_CODEC_OPTIONS
from pymongo import __version__, _csot
from pymongo.asynchronous import helpers
from pymongo import _csot, helpers_shared
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.common import (
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.network import command, receive_message
from pymongo.common import (
MAX_BSON_SIZE,
MAX_CONNECTING,
MAX_IDLE_TIME_SEC,
MAX_MESSAGE_SIZE,
MAX_POOL_SIZE,
MAX_WIRE_VERSION,
MAX_WRITE_BATCH_SIZE,
MIN_POOL_SIZE,
ORDERED_TYPES,
WAIT_QUEUE_TIMEOUT,
)
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
_debug_log,
_verbose_connection_error_reason,
)
from pymongo.asynchronous.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
_EventListeners,
)
from pymongo.asynchronous.network import command, receive_message
from pymongo.asynchronous.read_preferences import ReadPreference
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConfigurationError,
@ -86,8 +62,21 @@ from pymongo.errors import ( # type:ignore[attr-defined]
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
_debug_log,
_verbose_connection_error_reason,
)
from pymongo.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
)
from pymongo.network_layer import async_sendall
from pymongo.pool_options import PoolOptions
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
@ -96,22 +85,19 @@ from pymongo.ssl_support import HAS_SNI, SSLError
if TYPE_CHECKING:
from bson import CodecOptions
from bson.objectid import ObjectId
from pymongo.asynchronous.auth import MongoCredential, _AuthContext
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.compression_support import (
CompressionSettings,
from pymongo.asynchronous.auth import _AuthContext
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
from pymongo.compression_support import (
SnappyContext,
ZlibContext,
ZstdContext,
)
from pymongo.asynchronous.message import _OpMsg, _OpReply
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.typings import ClusterTime, _Address, _CollationIn
from pymongo.driver_info import DriverInfo
from pymongo.pyopenssl_context import SSLContext, _sslConn
from pymongo.message import _OpMsg, _OpReply
from pymongo.pyopenssl_context import _sslConn
from pymongo.read_concern import ReadConcern
from pymongo.server_api import ServerApi
from pymongo.read_preferences import _ServerMode
from pymongo.typings import ClusterTime, _Address, _CollationIn
from pymongo.write_concern import WriteConcern
try:
@ -191,236 +177,6 @@ else:
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}}
if sys.platform.startswith("linux"):
_METADATA["os"] = {
"type": platform.system(),
"name": platform.system(),
"architecture": platform.machine(),
# Kernel version (e.g. 4.4.0-17-generic).
"version": platform.release(),
}
elif sys.platform == "darwin":
_METADATA["os"] = {
"type": platform.system(),
"name": platform.system(),
"architecture": platform.machine(),
# (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin
# kernel version.
"version": platform.mac_ver()[0],
}
elif sys.platform == "win32":
_ver = sys.getwindowsversion()
_METADATA["os"] = {
"type": "Windows",
"name": "Windows",
# Avoid using platform calls, see PYTHON-4455.
"architecture": os.environ.get("PROCESSOR_ARCHITECTURE") or platform.machine(),
# Windows patch level (e.g. 10.0.17763-SP0).
"version": ".".join(map(str, _ver[:3])) + f"-SP{_ver[-1] or '0'}",
}
elif sys.platform.startswith("java"):
_name, _ver, _arch = platform.java_ver()[-1]
_METADATA["os"] = {
# Linux, Windows 7, Mac OS X, etc.
"type": _name,
"name": _name,
# x86, x86_64, AMD64, etc.
"architecture": _arch,
# Linux kernel version, OSX version, etc.
"version": _ver,
}
else:
# Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11)
_aliased = platform.system_alias(platform.system(), platform.release(), platform.version())
_METADATA["os"] = {
"type": platform.system(),
"name": " ".join([part for part in _aliased[:2] if part]),
"architecture": platform.machine(),
"version": _aliased[2],
}
if platform.python_implementation().startswith("PyPy"):
_METADATA["platform"] = " ".join(
(
platform.python_implementation(),
".".join(map(str, sys.pypy_version_info)), # type: ignore
"(Python %s)" % ".".join(map(str, sys.version_info)),
)
)
elif sys.platform.startswith("java"):
_METADATA["platform"] = " ".join(
(
platform.python_implementation(),
".".join(map(str, sys.version_info)),
"(%s)" % " ".join((platform.system(), platform.release())),
)
)
else:
_METADATA["platform"] = " ".join(
(platform.python_implementation(), ".".join(map(str, sys.version_info)))
)
DOCKER_ENV_PATH = "/.dockerenv"
ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST"
RUNTIME_NAME_DOCKER = "docker"
ORCHESTRATOR_NAME_K8S = "kubernetes"
def get_container_env_info() -> dict[str, str]:
"""Returns the runtime and orchestrator of a container.
If neither value is present, the metadata client.env.container field will be omitted."""
container = {}
if Path(DOCKER_ENV_PATH).exists():
container["runtime"] = RUNTIME_NAME_DOCKER
if os.getenv(ENV_VAR_K8S):
container["orchestrator"] = ORCHESTRATOR_NAME_K8S
return container
def _is_lambda() -> bool:
if os.getenv("AWS_LAMBDA_RUNTIME_API"):
return True
env = os.getenv("AWS_EXECUTION_ENV")
if env:
return env.startswith("AWS_Lambda_")
return False
def _is_azure_func() -> bool:
return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME"))
def _is_gcp_func() -> bool:
return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME"))
def _is_vercel() -> bool:
return bool(os.getenv("VERCEL"))
def _is_faas() -> bool:
return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel()
def _getenv_int(key: str) -> Optional[int]:
"""Like os.getenv but returns an int, or None if the value is missing/malformed."""
val = os.getenv(key)
if not val:
return None
try:
return int(val)
except ValueError:
return None
def _metadata_env() -> dict[str, Any]:
env: dict[str, Any] = {}
container = get_container_env_info()
if container:
env["container"] = container
# Skip if multiple (or no) envs are matched.
if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1:
return env
if _is_lambda():
env["name"] = "aws.lambda"
region = os.getenv("AWS_REGION")
if region:
env["region"] = region
memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
if memory_mb is not None:
env["memory_mb"] = memory_mb
elif _is_azure_func():
env["name"] = "azure.func"
elif _is_gcp_func():
env["name"] = "gcp.func"
region = os.getenv("FUNCTION_REGION")
if region:
env["region"] = region
memory_mb = _getenv_int("FUNCTION_MEMORY_MB")
if memory_mb is not None:
env["memory_mb"] = memory_mb
timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC")
if timeout_sec is not None:
env["timeout_sec"] = timeout_sec
elif _is_vercel():
env["name"] = "vercel"
region = os.getenv("VERCEL_REGION")
if region:
env["region"] = region
return env
_MAX_METADATA_SIZE = 512
# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations
def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None:
"""Perform metadata truncation."""
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
# 1. Omit fields from env except env.name.
env_name = metadata.get("env", {}).get("name")
if env_name:
metadata["env"] = {"name": env_name}
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
# 2. Omit fields from os except os.type.
os_type = metadata.get("os", {}).get("type")
if os_type:
metadata["os"] = {"type": os_type}
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
# 3. Omit the env document entirely.
metadata.pop("env", None)
encoded_size = len(bson.encode(metadata))
if encoded_size <= _MAX_METADATA_SIZE:
return
# 4. Truncate platform.
overflow = encoded_size - _MAX_METADATA_SIZE
plat = metadata.get("platform", "")
if plat:
plat = plat[:-overflow]
if plat:
metadata["platform"] = plat
else:
metadata.pop("platform", None)
encoded_size = len(bson.encode(metadata))
if encoded_size <= _MAX_METADATA_SIZE:
return
# 5. Truncate driver info.
overflow = encoded_size - _MAX_METADATA_SIZE
driver = metadata.get("driver", {})
if driver:
# Truncate driver version.
driver_version = driver.get("version")[:-overflow]
if len(driver_version) >= len(_METADATA["driver"]["version"]):
metadata["driver"]["version"] = driver_version
else:
metadata["driver"]["version"] = _METADATA["driver"]["version"]
encoded_size = len(bson.encode(metadata))
if encoded_size <= _MAX_METADATA_SIZE:
return
# Truncate driver name.
overflow = encoded_size - _MAX_METADATA_SIZE
driver_name = driver.get("name")[:-overflow]
if len(driver_name) >= len(_METADATA["driver"]["name"]):
metadata["driver"]["name"] = driver_name
else:
metadata["driver"]["name"] = _METADATA["driver"]["name"]
# If the first getaddrinfo call of this interpreter's life is on a thread,
# while the main thread holds the import lock, getaddrinfo deadlocks trying
# to import the IDNA codec. Import it here, where presumably we're on the
# main thread, to avoid the deadlock. See PYTHON-607.
"foo".encode("idna")
def _raise_connection_failure(
address: Any,
error: Exception,
@ -481,238 +237,6 @@ def format_timeout_details(details: Optional[dict[str, float]]) -> str:
return result
class PoolOptions:
"""Read only connection pool options for an AsyncMongoClient.
Should not be instantiated directly by application developers. Access
a client's pool options via
:attr:`~pymongo.client_options.ClientOptions.pool_options` instead::
pool_opts = client.options.pool_options
pool_opts.max_pool_size
pool_opts.min_pool_size
"""
__slots__ = (
"__max_pool_size",
"__min_pool_size",
"__max_idle_time_seconds",
"__connect_timeout",
"__socket_timeout",
"__wait_queue_timeout",
"__ssl_context",
"__tls_allow_invalid_hostnames",
"__event_listeners",
"__appname",
"__driver",
"__metadata",
"__compression_settings",
"__max_connecting",
"__pause_enabled",
"__server_api",
"__load_balanced",
"__credentials",
)
def __init__(
self,
max_pool_size: int = MAX_POOL_SIZE,
min_pool_size: int = MIN_POOL_SIZE,
max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC,
connect_timeout: Optional[float] = None,
socket_timeout: Optional[float] = None,
wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT,
ssl_context: Optional[SSLContext] = None,
tls_allow_invalid_hostnames: bool = False,
event_listeners: Optional[_EventListeners] = None,
appname: Optional[str] = None,
driver: Optional[DriverInfo] = None,
compression_settings: Optional[CompressionSettings] = None,
max_connecting: int = MAX_CONNECTING,
pause_enabled: bool = True,
server_api: Optional[ServerApi] = None,
load_balanced: Optional[bool] = None,
credentials: Optional[MongoCredential] = None,
):
self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size
self.__max_idle_time_seconds = max_idle_time_seconds
self.__connect_timeout = connect_timeout
self.__socket_timeout = socket_timeout
self.__wait_queue_timeout = wait_queue_timeout
self.__ssl_context = ssl_context
self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames
self.__event_listeners = event_listeners
self.__appname = appname
self.__driver = driver
self.__compression_settings = compression_settings
self.__max_connecting = max_connecting
self.__pause_enabled = pause_enabled
self.__server_api = server_api
self.__load_balanced = load_balanced
self.__credentials = credentials
self.__metadata = copy.deepcopy(_METADATA)
if appname:
self.__metadata["application"] = {"name": appname}
# Combine the "driver" AsyncMongoClient option with PyMongo's info, like:
# {
# 'driver': {
# 'name': 'PyMongo|MyDriver',
# 'version': '4.2.0|1.2.3',
# },
# 'platform': 'CPython 3.8.0|MyPlatform'
# }
if driver:
if driver.name:
self.__metadata["driver"]["name"] = "{}|{}".format(
_METADATA["driver"]["name"],
driver.name,
)
if driver.version:
self.__metadata["driver"]["version"] = "{}|{}".format(
_METADATA["driver"]["version"],
driver.version,
)
if driver.platform:
self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform)
env = _metadata_env()
if env:
self.__metadata["env"] = env
_truncate_metadata(self.__metadata)
@property
def _credentials(self) -> Optional[MongoCredential]:
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
return self.__credentials
@property
def non_default_options(self) -> dict[str, Any]:
"""The non-default options this pool was created with.
Added for CMAP's :class:`PoolCreatedEvent`.
"""
opts = {}
if self.__max_pool_size != MAX_POOL_SIZE:
opts["maxPoolSize"] = self.__max_pool_size
if self.__min_pool_size != MIN_POOL_SIZE:
opts["minPoolSize"] = self.__min_pool_size
if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC:
assert self.__max_idle_time_seconds is not None
opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000
if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT:
assert self.__wait_queue_timeout is not None
opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000
if self.__max_connecting != MAX_CONNECTING:
opts["maxConnecting"] = self.__max_connecting
return opts
@property
def max_pool_size(self) -> float:
"""The maximum allowable number of concurrent connections to each
connected server. Requests to a server will block if there are
`maxPoolSize` outstanding connections to the requested server.
Defaults to 100. Cannot be 0.
When a server's pool has reached `max_pool_size`, operations for that
server block waiting for a socket to be returned to the pool. If
``waitQueueTimeoutMS`` is set, a blocked operation will raise
:exc:`~pymongo.errors.ConnectionFailure` after a timeout.
By default ``waitQueueTimeoutMS`` is not set.
"""
return self.__max_pool_size
@property
def min_pool_size(self) -> int:
"""The minimum required number of concurrent connections that the pool
will maintain to each connected server. Default is 0.
"""
return self.__min_pool_size
@property
def max_connecting(self) -> int:
"""The maximum number of concurrent connection creation attempts per
pool. Defaults to 2.
"""
return self.__max_connecting
@property
def pause_enabled(self) -> bool:
return self.__pause_enabled
@property
def max_idle_time_seconds(self) -> Optional[int]:
"""The maximum number of seconds that a connection can remain
idle in the pool before being removed and replaced. Defaults to
`None` (no limit).
"""
return self.__max_idle_time_seconds
@property
def connect_timeout(self) -> Optional[float]:
"""How long a connection can take to be opened before timing out."""
return self.__connect_timeout
@property
def socket_timeout(self) -> Optional[float]:
"""How long a send or receive on a socket can take before timing out."""
return self.__socket_timeout
@property
def wait_queue_timeout(self) -> Optional[int]:
"""How long a thread will wait for a socket from the pool if the pool
has no free sockets.
"""
return self.__wait_queue_timeout
@property
def _ssl_context(self) -> Optional[SSLContext]:
"""An SSLContext instance or None."""
return self.__ssl_context
@property
def tls_allow_invalid_hostnames(self) -> bool:
"""If True skip ssl.match_hostname."""
return self.__tls_allow_invalid_hostnames
@property
def _event_listeners(self) -> Optional[_EventListeners]:
"""An instance of pymongo.monitoring._EventListeners."""
return self.__event_listeners
@property
def appname(self) -> Optional[str]:
"""The application name, for sending with hello in server handshake."""
return self.__appname
@property
def driver(self) -> Optional[DriverInfo]:
"""Driver name and version, for sending with hello in handshake."""
return self.__driver
@property
def _compression_settings(self) -> Optional[CompressionSettings]:
return self.__compression_settings
@property
def metadata(self) -> dict[str, Any]:
"""A dict of metadata about the application, driver, os, and platform."""
return self.__metadata.copy()
@property
def server_api(self) -> Optional[ServerApi]:
"""A pymongo.server_api.ServerApi or None."""
return self.__server_api
@property
def load_balanced(self) -> Optional[bool]:
"""True if this Pool is configured in load balanced mode."""
return self.__load_balanced
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
@ -727,7 +251,7 @@ class _CancellationContext:
return self._cancelled
class Connection:
class AsyncConnection:
"""Store a connection with some metadata.
:param conn: a raw connection object
@ -946,7 +470,7 @@ class Connection:
self.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response()
response_doc = unpacked_docs[0]
helpers._check_command_response(response_doc, self.max_wire_version)
helpers_shared._check_command_response(response_doc, self.max_wire_version)
return response_doc
@_handle_reauth
@ -962,7 +486,7 @@ class Connection:
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
session: Optional[ClientSession] = None,
session: Optional[AsyncClientSession] = None,
client: Optional[AsyncMongoClient] = None,
retryable_write: bool = False,
publish_events: bool = True,
@ -982,7 +506,7 @@ class Connection:
:param parse_write_concern_error: Whether to parse the
``writeConcernError`` field in the command response.
:param collation: The collation for this command.
:param session: optional ClientSession instance.
:param session: optional AsyncClientSession instance.
:param client: optional AsyncMongoClient for gossipping $clusterTime.
:param retryable_write: True if this command is a retryable write.
:param publish_events: Should we publish events for this command?
@ -1099,7 +623,7 @@ class Connection:
result = reply.command_response(codec_options)
# Raises NotPrimaryError or OperationFailure.
helpers._check_command_response(result, self.max_wire_version)
helpers_shared._check_command_response(result, self.max_wire_version)
return result
async def authenticate(self, reauthenticate: bool = False) -> None:
@ -1137,7 +661,7 @@ class Connection:
)
def validate_session(
self, client: Optional[AsyncMongoClient], session: Optional[ClientSession]
self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession]
) -> None:
"""Validate this session before use with client.
@ -1190,7 +714,7 @@ class Connection:
def send_cluster_time(
self,
command: MutableMapping[str, Any],
session: Optional[ClientSession],
session: Optional[AsyncClientSession],
client: Optional[AsyncMongoClient],
) -> None:
"""Add $clusterTime."""
@ -1250,7 +774,7 @@ class Connection:
return hash(self.conn)
def __repr__(self) -> str:
return "Connection({}){} at {}".format(
return "AsyncConnection({}){} at {}".format(
repr(self.conn),
self.closed and " CLOSED" or "",
id(self),
@ -1441,7 +965,7 @@ class Pool:
"""
:param address: a (hostname, port) tuple
:param options: a PoolOptions instance
:param handshake: whether to call hello for each new Connection
:param handshake: whether to call hello for each new AsyncConnection
"""
if options.pause_enabled:
self.state = PoolState.PAUSED
@ -1512,7 +1036,7 @@ class Pool:
# Retain references to pinned connections to prevent the CPython GC
# from thinking that a cursor's pinned connection can be GC'd when the
# cursor is GC'd (see PYTHON-2751).
self.__pinned_sockets: set[Connection] = set()
self.__pinned_sockets: set[AsyncConnection] = set()
self.ncursors = 0
self.ntxns = 0
@ -1700,8 +1224,8 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
"""Connect to Mongo and return a new Connection.
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
"""Connect to Mongo and return a new AsyncConnection.
Can raise ConnectionFailure.
@ -1751,7 +1275,7 @@ class Pool:
raise
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
try:
@ -1771,10 +1295,10 @@ class Pool:
@contextlib.asynccontextmanager
async def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None
) -> AsyncGenerator[Connection, None]:
) -> AsyncGenerator[AsyncConnection, None]:
"""Get a connection from the pool. Use with a "with" statement.
Returns a :class:`Connection` object wrapping a connected
Returns a :class:`AsyncConnection` object wrapping a connected
:class:`socket.socket`.
This method should always be used in a with-statement::
@ -1874,8 +1398,8 @@ class Pool:
async def _get_conn(
self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None
) -> Connection:
"""Get or create a Connection. Can raise ConnectionFailure."""
) -> AsyncConnection:
"""Get or create a AsyncConnection. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of
# what could go wrong otherwise
@ -1998,7 +1522,7 @@ class Pool:
conn.active = True
return conn
async def checkin(self, conn: Connection) -> None:
async def checkin(self, conn: AsyncConnection) -> None:
"""Return the connection to the pool, or if it's closed discard it.
:param conn: The connection to check into the pool.
@ -2070,7 +1594,7 @@ class Pool:
self.operation_count -= 1
self.size_cond.notify()
def _perished(self, conn: Connection) -> bool:
def _perished(self, conn: AsyncConnection) -> bool:
"""Return True and close the connection if it is "perished".
This side-effecty function checks if this socket has been idle for

View File

@ -1,624 +0,0 @@
# Copyright 2012-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License",
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for choosing which member of a replica set to read from."""
from __future__ import annotations
from collections import abc
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
from pymongo.asynchronous import max_staleness_selectors
from pymongo.asynchronous.server_selectors import (
member_with_tags_server_selector,
secondary_with_tags_server_selector,
)
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from pymongo.asynchronous.server_selectors import Selection
from pymongo.asynchronous.topology_description import TopologyDescription
_IS_SYNC = False
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
_SECONDARY_PREFERRED = 3
_NEAREST = 4
_MONGOS_MODES = (
"primary",
"primaryPreferred",
"secondary",
"secondaryPreferred",
"nearest",
)
_Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
"""Validate tag sets for a MongoClient."""
if tag_sets is None:
return tag_sets
if not isinstance(tag_sets, (list, tuple)):
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
if len(tag_sets) == 0:
raise ValueError(
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
)
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
f"Tag set {tags!r} invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping"
)
return list(tag_sets)
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
# Some duplication with common.py to avoid import cycle.
def _validate_max_staleness(max_staleness: Any) -> int:
"""Validate max_staleness."""
if max_staleness == -1:
return -1
if not isinstance(max_staleness, int):
raise TypeError(_invalid_max_staleness_msg(max_staleness))
if max_staleness <= 0:
raise ValueError(_invalid_max_staleness_msg(max_staleness))
return max_staleness
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
"""Validate hedge."""
if hedge is None:
return None
if not isinstance(hedge, dict):
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
return hedge
class _ServerMode:
"""Base class for all read preferences."""
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
def __init__(
self,
mode: int,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
self.__mongos_mode = _MONGOS_MODES[mode]
self.__mode = mode
self.__tag_sets = _validate_tag_sets(tag_sets)
self.__max_staleness = _validate_max_staleness(max_staleness)
self.__hedge = _validate_hedge(hedge)
@property
def name(self) -> str:
"""The name of this read preference."""
return self.__class__.__name__
@property
def mongos_mode(self) -> str:
"""The mongos mode of this read preference."""
return self.__mongos_mode
@property
def document(self) -> dict[str, Any]:
"""Read preference as a document."""
doc: dict[str, Any] = {"mode": self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc["tags"] = self.__tag_sets
if self.__max_staleness != -1:
doc["maxStalenessSeconds"] = self.__max_staleness
if self.__hedge not in (None, {}):
doc["hedge"] = self.__hedge
return doc
@property
def mode(self) -> int:
"""The mode of this read preference instance."""
return self.__mode
@property
def tag_sets(self) -> _TagSets:
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
read only from members whose ``dc`` tag has the value ``"ny"``.
To specify a priority-order for tag sets, provide a list of
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
set, ``{}``, means "read from any member that matches the mode,
ignoring tags." MongoClient tries each set of tags in turn
until it finds a set of tags with at least one matching member.
For example, to only send a query to an analytic node::
Nearest(tag_sets=[{"node":"analytics"}])
Or using :class:`SecondaryPreferred`::
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
.. seealso:: `Data-Center Awareness
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
"""
return list(self.__tag_sets) if self.__tag_sets else [{}]
@property
def max_staleness(self) -> int:
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum.
"""
return self.__max_staleness
@property
def hedge(self) -> Optional[_Hedge]:
"""The read preference ``hedge`` parameter.
A dictionary that configures how the server will perform hedged reads.
It consists of the following keys:
- ``enabled``: Enables or disables hedged reads in sharded clusters.
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
``nearest`` read preference. To explicitly enable hedged reads, set
the ``enabled`` key to ``true``::
>>> Nearest(hedge={'enabled': True})
To explicitly disable hedged reads, set the ``enabled`` key to
``False``::
>>> Nearest(hedge={'enabled': False})
.. versionadded:: 3.11
"""
return self.__hedge
@property
def min_wire_version(self) -> int:
"""The wire protocol version the server must support.
Some read preferences impose version requirements on all servers (e.g.
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
All servers' maxWireVersion must be at least this read preference's
`min_wire_version`, or the driver raises
:exc:`~pymongo.errors.ConfigurationError`.
"""
return 0 if self.__max_staleness == -1 else 5
def __repr__(self) -> str:
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
self.name,
self.__tag_sets,
self.__max_staleness,
self.__hedge,
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return (
self.mode == other.mode
and self.tag_sets == other.tag_sets
and self.max_staleness == other.max_staleness
and self.hedge == other.hedge
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __getstate__(self) -> dict[str, Any]:
"""Return value of object for pickling.
Needed explicitly because __slots__() defined.
"""
return {
"mode": self.__mode,
"tag_sets": self.__tag_sets,
"max_staleness": self.__max_staleness,
"hedge": self.__hedge,
}
def __setstate__(self, value: Mapping[str, Any]) -> None:
"""Restore from pickling."""
self.__mode = value["mode"]
self.__mongos_mode = _MONGOS_MODES[self.__mode]
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
self.__hedge = _validate_hedge(value["hedge"])
def __call__(self, selection: Selection) -> Selection:
return selection
class Primary(_ServerMode):
"""Primary read preference.
* When directly connected to one mongod queries are allowed if the server
is standalone or a replica set primary.
* When connected to a mongos queries are sent to the primary of a shard.
* When connected to a replica set queries are sent to the primary of
the replica set.
"""
__slots__ = ()
def __init__(self) -> None:
super().__init__(_PRIMARY)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return selection.primary_selection
def __repr__(self) -> str:
return "Primary()"
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return other.mode == _PRIMARY
return NotImplemented
class PrimaryPreferred(_ServerMode):
"""PrimaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are sent to the primary of a shard if
available, otherwise a shard secondary.
* When connected to a replica set queries are sent to the primary if
available, otherwise a secondary.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to an available secondary until the
primary of the replica set is discovered.
:param tag_sets: The :attr:`~tag_sets` to use if the primary is not
available.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` to use if the primary is not available.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
if selection.primary:
return selection.primary_selection
else:
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class Secondary(_ServerMode):
"""Secondary read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries. An error is raised if no secondaries are available.
* When connected to a replica set queries are distributed among
secondaries. An error is raised if no secondaries are available.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class SecondaryPreferred(_ServerMode):
"""SecondaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries, or the shard primary if no secondary is available.
* When connected to a replica set queries are distributed among
secondaries, or the primary if no secondary is available.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to the primary of the replica set until
an available secondary is discovered.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
secondaries = secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
if secondaries:
return secondaries
else:
return selection.primary_selection
class Nearest(_ServerMode):
"""Nearest read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among all members of
a shard.
* When connected to a replica set queries are distributed among all
members.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return member_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class _AggWritePref:
"""Agg $out/$merge write preference.
* If there are readable servers and there is any pre-5.0 server, use
primary read preference.
* Otherwise use `pref` read preference.
:param pref: The read preference to use on MongoDB 5.0+.
"""
__slots__ = ("pref", "effective_pref")
def __init__(self, pref: _ServerMode):
self.pref = pref
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
def selection_hook(self, topology_description: TopologyDescription) -> None:
common_wv = topology_description.common_wire_version
if (
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
and common_wv
and common_wv < 13
):
self.effective_pref = ReadPreference.PRIMARY
else:
self.effective_pref = self.pref
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return self.effective_pref(selection)
def __repr__(self) -> str:
return f"_AggWritePref(pref={self.pref!r})"
# Proxy other calls to the effective_pref so that _AggWritePref can be
# used in place of an actual read preference.
def __getattr__(self, name: str) -> Any:
return getattr(self.effective_pref, name)
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
def make_read_preference(
mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
) -> _ServerMode:
if mode == _PRIMARY:
if tag_sets not in (None, [{}]):
raise ConfigurationError("Read preference primary cannot be combined with tags")
if max_staleness != -1:
raise ConfigurationError(
"Read preference primary cannot be combined with maxStalenessSeconds"
)
return Primary()
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
_MODES = (
"PRIMARY",
"PRIMARY_PREFERRED",
"SECONDARY",
"SECONDARY_PREFERRED",
"NEAREST",
)
class ReadPreference:
"""An enum that defines some commonly used read preference modes.
Apps can also create a custom read preference, for example::
Nearest(tag_sets=[{"node":"analytics"}])
See :doc:`/examples/high_availability` for code examples.
A read preference is used in three cases:
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
set primary.
- All other modes allow queries to standalone servers, to a replica set
primary, or to replica set secondaries.
:class:`~pymongo.mongo_client.MongoClient` initialized with the
``replicaSet`` option:
- ``PRIMARY``: Read from the primary. This is the default, and provides the
strongest consistency. If no primary is available, raise
:class:`~pymongo.errors.AutoReconnect`.
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
none, read from a secondary.
- ``SECONDARY``: Read from a secondary. If no secondary is available,
raise :class:`~pymongo.errors.AutoReconnect`.
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
from the primary.
- ``NEAREST``: Read from any member.
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
sharded cluster of replica sets:
- ``PRIMARY``: Read from the primary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
This is the default.
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
none, read from a secondary of the shard.
- ``SECONDARY``: Read from a secondary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
otherwise from the shard primary.
- ``NEAREST``: Read from any shard member.
"""
PRIMARY = Primary()
PRIMARY_PREFERRED = PrimaryPreferred()
SECONDARY = Secondary()
SECONDARY_PREFERRED = SecondaryPreferred()
NEAREST = Nearest()
def read_pref_mode_from_name(name: str) -> int:
"""Get the read preference mode from mongos/uri name."""
return _MONGOS_MODES.index(name)
class MovingAverage:
"""Tracks an exponentially-weighted moving average."""
average: Optional[float]
def __init__(self) -> None:
self.average = None
def add_sample(self, sample: float) -> None:
if sample < 0:
# Likely system time change while waiting for hello response
# and not using time.monotonic. Ignore it, the next one will
# probably be valid.
return
if self.average is None:
self.average = sample
else:
# The Server Selection Spec requires an exponentially weighted
# average with alpha = 0.2.
self.average = 0.8 * self.average + 0.2 * sample
def get(self) -> Optional[float]:
"""Get the calculated average, or None if no samples yet."""
return self.average
def reset(self) -> None:
self.average = None

View File

@ -1,133 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent a response from the server."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.asynchronous.message import _OpMsg, _OpReply
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.typings import _Address, _DocumentOut
_IS_SYNC = False
class Response:
__slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs")
def __init__(
self,
data: Union[_OpMsg, _OpReply],
address: _Address,
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: Sequence[Mapping[str, Any]],
):
"""Represent a response from the server.
:param data: A network response message.
:param address: (host, port) of the source server.
:param request_id: The request id of this operation.
:param duration: The duration of the operation.
:param from_command: if the response is the result of a db command.
"""
self._data = data
self._address = address
self._request_id = request_id
self._duration = duration
self._from_command = from_command
self._docs = docs
@property
def data(self) -> Union[_OpMsg, _OpReply]:
"""Server response's raw BSON bytes."""
return self._data
@property
def address(self) -> _Address:
"""(host, port) of the source server."""
return self._address
@property
def request_id(self) -> int:
"""The request id of this operation."""
return self._request_id
@property
def duration(self) -> Optional[timedelta]:
"""The duration of the operation."""
return self._duration
@property
def from_command(self) -> bool:
"""If the response is a result from a db command."""
return self._from_command
@property
def docs(self) -> Sequence[Mapping[str, Any]]:
"""The decoded document(s)."""
return self._docs
class PinnedResponse(Response):
__slots__ = ("_conn", "_more_to_come")
def __init__(
self,
data: Union[_OpMsg, _OpReply],
address: _Address,
conn: Connection,
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: list[_DocumentOut],
more_to_come: bool,
):
"""Represent a response to an exhaust cursor's initial query.
:param data: A network response message.
:param address: (host, port) of the source server.
:param conn: The Connection used for the initial query.
:param request_id: The request id of this operation.
:param duration: The duration of the operation.
:param from_command: If the response is the result of a db command.
:param docs: List of documents.
:param more_to_come: Bool indicating whether cursor is ready to be
exhausted.
"""
super().__init__(data, address, request_id, duration, from_command, docs)
self._conn = conn
self._more_to_come = more_to_come
@property
def conn(self) -> Connection:
"""The Connection used for the initial query.
The server will send batches on this socket, without waiting for
getMores from the client, until the result set is exhausted or there
is an error.
"""
return self._conn
@property
def more_to_come(self) -> bool:
"""If true, server is ready to send batches on the socket until the
result set is exhausted or there is an error.
"""
return self._more_to_come

View File

@ -27,11 +27,12 @@ from typing import (
)
from bson import _decode_all_selective
from pymongo.asynchronous.helpers import _check_command_response, _handle_reauth
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query
from pymongo.asynchronous.response import PinnedResponse, Response
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.errors import NotPrimaryError, OperationFailure
from pymongo.helpers_shared import _check_command_response
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query
from pymongo.response import PinnedResponse, Response
if TYPE_CHECKING:
from queue import Queue
@ -40,11 +41,11 @@ if TYPE_CHECKING:
from bson.objectid import ObjectId
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
from pymongo.asynchronous.monitor import Monitor
from pymongo.asynchronous.monitoring import _EventListeners
from pymongo.asynchronous.pool import Connection, Pool
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.typings import _DocumentOut
from pymongo.asynchronous.pool import AsyncConnection, Pool
from pymongo.monitoring import _EventListeners
from pymongo.read_preferences import _ServerMode
from pymongo.server_description import ServerDescription
from pymongo.typings import _DocumentOut
_IS_SYNC = False
@ -105,10 +106,23 @@ class Server:
"""Check the server's state soon."""
self._monitor.request_check()
async def operation_to_command(
self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False
) -> tuple[dict[str, Any], str]:
cmd, db = operation.as_command(conn, apply_timeout)
# Support auto encryption
if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment]
operation.db, cmd, operation.codec_options
)
operation.update_command(cmd)
return cmd, db
@_handle_reauth
async def run_operation(
self,
conn: Connection,
conn: AsyncConnection,
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
@ -121,26 +135,26 @@ class Server:
cursors.
Can raise ConnectionFailure, OperationFailure, etc.
:param conn: A Connection instance.
:param conn: An AsyncConnection instance.
:param operation: A _Query or _GetMore object.
:param read_preference: The read preference to use.
:param listeners: Instance of _EventListeners or None.
:param unpack_res: A callable that decodes the wire protocol response.
:param client: An AsyncMongoClient instance.
"""
duration = None
assert listeners is not None
publish = listeners.enabled_for_commands
start = datetime.now()
use_cmd = operation.use_command(conn)
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
cmd, dbn = await self.operation_to_command(operation, conn, use_cmd)
if more_to_come:
request_id = 0
else:
message = await operation.get_message(read_preference, conn, use_cmd)
message = operation.get_message(read_preference, conn, use_cmd)
request_id, data, max_doc_size = self._split_message(message)
cmd, dbn = await operation.as_command(conn)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
@ -159,7 +173,6 @@ class Server:
)
if publish:
cmd, dbn = await operation.as_command(conn)
if "$db" not in cmd:
cmd["$db"] = dbn
assert listeners is not None
@ -195,7 +208,7 @@ class Server:
)
if use_cmd:
first = docs[0]
await operation.client._process_response(first, operation.session)
await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type]
_check_command_response(first, conn.max_wire_version)
except Exception as exc:
duration = datetime.now() - start
@ -278,7 +291,7 @@ class Server:
)
# Decrypt response.
client = operation.client
client = operation.client # type: ignore[assignment]
if client and client._encrypter:
if use_cmd:
decrypted = client._encrypter.decrypt(reply.raw_command_response())
@ -286,7 +299,7 @@ class Server:
response: Response
if client._should_pin_cursor(operation.session) or operation.exhaust:
if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type]
conn.pin_cursor()
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the
@ -321,7 +334,7 @@ class Server:
async def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None
) -> AsyncContextManager[Connection]:
) -> AsyncContextManager[AsyncConnection]:
return self.pool.checkout(handler)
@property

View File

@ -1,301 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent one server the driver is connected to."""
from __future__ import annotations
import time
import warnings
from typing import Any, Mapping, Optional
from bson import EPOCH_NAIVE
from bson.objectid import ObjectId
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.typings import ClusterTime, _Address
from pymongo.server_type import SERVER_TYPE
_IS_SYNC = False
class ServerDescription:
"""Immutable representation of one server.
:param address: A (host, port) pair
:param hello: Optional Hello instance
:param round_trip_time: Optional float
:param error: Optional, the last error attempting to connect to the server
:param round_trip_time: Optional float, the min latency from the most recent samples
"""
__slots__ = (
"_address",
"_server_type",
"_all_hosts",
"_tags",
"_replica_set_name",
"_primary",
"_max_bson_size",
"_max_message_size",
"_max_write_batch_size",
"_min_wire_version",
"_max_wire_version",
"_round_trip_time",
"_min_round_trip_time",
"_me",
"_is_writable",
"_is_readable",
"_ls_timeout_minutes",
"_error",
"_set_version",
"_election_id",
"_cluster_time",
"_last_write_date",
"_last_update_time",
"_topology_version",
)
def __init__(
self,
address: _Address,
hello: Optional[Hello] = None,
round_trip_time: Optional[float] = None,
error: Optional[Exception] = None,
min_round_trip_time: float = 0.0,
) -> None:
self._address = address
if not hello:
hello = Hello({})
self._server_type = hello.server_type
self._all_hosts = hello.all_hosts
self._tags = hello.tags
self._replica_set_name = hello.replica_set_name
self._primary = hello.primary
self._max_bson_size = hello.max_bson_size
self._max_message_size = hello.max_message_size
self._max_write_batch_size = hello.max_write_batch_size
self._min_wire_version = hello.min_wire_version
self._max_wire_version = hello.max_wire_version
self._set_version = hello.set_version
self._election_id = hello.election_id
self._cluster_time = hello.cluster_time
self._is_writable = hello.is_writable
self._is_readable = hello.is_readable
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
self._round_trip_time = round_trip_time
self._min_round_trip_time = min_round_trip_time
self._me = hello.me
self._last_update_time = time.monotonic()
self._error = error
self._topology_version = hello.topology_version
if error:
details = getattr(error, "details", None)
if isinstance(details, dict):
self._topology_version = details.get("topologyVersion")
self._last_write_date: Optional[float]
if hello.last_write_date:
# Convert from datetime to seconds.
delta = hello.last_write_date - EPOCH_NAIVE
self._last_write_date = delta.total_seconds()
else:
self._last_write_date = None
@property
def address(self) -> _Address:
"""The address (host, port) of this server."""
return self._address
@property
def server_type(self) -> int:
"""The type of this server."""
return self._server_type
@property
def server_type_name(self) -> str:
"""The server type as a human readable string.
.. versionadded:: 3.4
"""
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@property
def tags(self) -> Mapping[str, Any]:
return self._tags
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._replica_set_name
@property
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
return self._primary
@property
def max_bson_size(self) -> int:
return self._max_bson_size
@property
def max_message_size(self) -> int:
return self._max_message_size
@property
def max_write_batch_size(self) -> int:
return self._max_write_batch_size
@property
def min_wire_version(self) -> int:
return self._min_wire_version
@property
def max_wire_version(self) -> int:
return self._max_wire_version
@property
def set_version(self) -> Optional[int]:
return self._set_version
@property
def election_id(self) -> Optional[ObjectId]:
return self._election_id
@property
def cluster_time(self) -> Optional[ClusterTime]:
return self._cluster_time
@property
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
warnings.warn(
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
DeprecationWarning,
stacklevel=2,
)
return self._set_version, self._election_id
@property
def me(self) -> Optional[tuple[str, int]]:
return self._me
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._ls_timeout_minutes
@property
def last_write_date(self) -> Optional[float]:
return self._last_write_date
@property
def last_update_time(self) -> float:
return self._last_update_time
@property
def round_trip_time(self) -> Optional[float]:
"""The current average latency or None."""
# This override is for unittesting only!
if self._address in self._host_to_round_trip_time:
return self._host_to_round_trip_time[self._address]
return self._round_trip_time
@property
def min_round_trip_time(self) -> float:
"""The min latency from the most recent samples."""
return self._min_round_trip_time
@property
def error(self) -> Optional[Exception]:
"""The last error attempting to connect to the server, or None."""
return self._error
@property
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self) -> bool:
return self._is_readable
@property
def mongos(self) -> bool:
return self._server_type == SERVER_TYPE.Mongos
@property
def is_server_type_known(self) -> bool:
return self.server_type != SERVER_TYPE.Unknown
@property
def retryable_writes_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return (
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
) or self._server_type == SERVER_TYPE.LoadBalancer
@property
def retryable_reads_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return self._max_wire_version >= 6
@property
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._topology_version
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
unknown = ServerDescription(self.address, error=error)
unknown._topology_version = self.topology_version
return unknown
def __eq__(self, other: Any) -> bool:
if isinstance(other, ServerDescription):
return (
(self._address == other.address)
and (self._server_type == other.server_type)
and (self._min_wire_version == other.min_wire_version)
and (self._max_wire_version == other.max_wire_version)
and (self._me == other.me)
and (self._all_hosts == other.all_hosts)
and (self._tags == other.tags)
and (self._replica_set_name == other.replica_set_name)
and (self._set_version == other.set_version)
and (self._election_id == other.election_id)
and (self._primary == other.primary)
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
and (self._error == other.error)
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
errmsg = ""
if self.error:
errmsg = f", error={self.error!r}"
return "<{} {} server_type: {}, rtt: {}{}>".format(
self.__class__.__name__,
self.address,
self.server_type_name,
self.round_trip_time,
errmsg,
)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time: dict = {}

View File

@ -1,175 +0,0 @@
# Copyright 2014-2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Criteria to select some ServerDescriptions from a TopologyDescription."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast
from pymongo.server_type import SERVER_TYPE
if TYPE_CHECKING:
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.topology_description import TopologyDescription
_IS_SYNC = False
T = TypeVar("T")
TagSet = Mapping[str, Any]
TagSets = Sequence[TagSet]
class Selection:
"""Input or output of a server selector function."""
@classmethod
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
known_servers = topology_description.known_servers
primary = None
for sd in known_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
primary = sd
break
return Selection(
topology_description,
topology_description.known_servers,
topology_description.common_wire_version,
primary,
)
def __init__(
self,
topology_description: TopologyDescription,
server_descriptions: list[ServerDescription],
common_wire_version: Optional[int],
primary: Optional[ServerDescription],
):
self.topology_description = topology_description
self.server_descriptions = server_descriptions
self.primary = primary
self.common_wire_version = common_wire_version
def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection:
return Selection(
self.topology_description, server_descriptions, self.common_wire_version, self.primary
)
def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]:
secondaries = secondary_server_selector(self)
if secondaries.server_descriptions:
return max(
secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date)
)
return None
@property
def primary_selection(self) -> Selection:
primaries = [self.primary] if self.primary else []
return self.with_server_descriptions(primaries)
@property
def heartbeat_frequency(self) -> int:
return self.topology_description.heartbeat_frequency
@property
def topology_type(self) -> int:
return self.topology_description.topology_type
def __bool__(self) -> bool:
return bool(self.server_descriptions)
def __getitem__(self, item: int) -> ServerDescription:
return self.server_descriptions[item]
def any_server_selector(selection: T) -> T:
return selection
def readable_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.is_readable]
)
def writable_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.is_writable]
)
def secondary_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary]
)
def arbiter_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter]
)
def writable_preferred_server_selector(selection: Selection) -> Selection:
"""Like PrimaryPreferred but doesn't use tags or latency."""
return writable_server_selector(selection) or secondary_server_selector(selection)
def apply_single_tag_set(tag_set: TagSet, selection: Selection) -> Selection:
"""All servers matching one tag set.
A tag set is a dict. A server matches if its tags are a superset:
A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}.
The empty tag set {} matches any server.
"""
def tags_match(server_tags: Mapping[str, Any]) -> bool:
for key, value in tag_set.items():
if key not in server_tags or server_tags[key] != value:
return False
return True
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if tags_match(s.tags)]
)
def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection:
"""All servers match a list of tag sets.
tag_sets is a list of dicts. The empty tag set {} matches any server,
and may be provided at the end of the list as a fallback. So
[{'a': 'value'}, {}] expresses a preference for servers tagged
{'a': 'value'}, but accepts any server if none matches the first
preference.
"""
for tag_set in tag_sets:
with_tag_set = apply_single_tag_set(tag_set, selection)
if with_tag_set:
return with_tag_set
return selection.with_server_descriptions([])
def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
"""All near-enough secondaries matching the tag sets."""
return apply_tag_sets(tag_sets, secondary_server_selector(selection))
def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
"""All near-enough members matching the tag sets."""
return apply_tag_sets(tag_sets, readable_server_selector(selection))

View File

@ -20,12 +20,14 @@ import traceback
from typing import Any, Collection, Optional, Type, Union
from bson.objectid import ObjectId
from pymongo.asynchronous import common, monitor, pool
from pymongo.asynchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.asynchronous.pool import Pool, PoolOptions
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector
from pymongo import common
from pymongo.asynchronous import monitor, pool
from pymongo.asynchronous.pool import Pool
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.errors import ConfigurationError
from pymongo.pool_options import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector
_IS_SYNC = False

View File

@ -1,149 +0,0 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
from __future__ import annotations
import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo.asynchronous.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from dns import resolver
_IS_SYNC = False
def _have_dnspython() -> bool:
try:
import dns # noqa: F401
return True
except ImportError:
return False
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
if isinstance(text, bytes):
return text.decode()
return text
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
from dns import resolver
if hasattr(resolver, "resolve"):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)
_INVALID_HOST_MSG = (
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
"Did you mean to use 'mongodb://'?"
)
class _SrvResolver:
def __init__(
self,
fqdn: str,
connect_timeout: Optional[float],
srv_service_name: str,
srv_max_hosts: int = 0,
):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
self.__srv_max_hosts = srv_max_hosts or 0
# Validate the fully qualified domain name.
try:
ipaddress.ip_address(fqdn)
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
except ValueError:
pass
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
def get_options(self) -> Optional[str]:
from dns import resolver
try:
results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc)) from None
if len(results) > 1:
raise ConfigurationError("Only one TXT record is supported")
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")
def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
try:
results = _resolve(
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
)
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc)) from None
return results
def _get_srv_response_and_hosts(
self, encapsulate_errors: bool
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
results = self._resolve_uri(encapsulate_errors)
# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results
]
# Validate hosts
for node in nodes:
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
def get_hosts(self) -> list[tuple[str, Any]]:
_, nodes = self._get_srv_response_and_hosts(True)
return nodes
def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
results, nodes = self._get_srv_response_and_hosts(False)
rrset = results.rrset
ttl = rrset.ttl if rrset else 0
return nodes, ttl

View File

@ -27,33 +27,12 @@ import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, helpers_constants
from pymongo.asynchronous import common, periodic_executor
from pymongo import _csot, common, helpers_shared
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.logger import (
_SERVER_SELECTION_LOGGER,
_debug_log,
_ServerSelectionStatusMessage,
)
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.pool import Pool, PoolOptions
from pymongo.asynchronous.pool import Pool
from pymongo.asynchronous.server import Server
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.server_selectors import (
Selection,
any_server_selector,
arbiter_server_selector,
secondary_server_selector,
writable_server_selector,
)
from pymongo.asynchronous.topology_description import (
SRV_POLLING_TOPOLOGIES,
TOPOLOGY_TYPE,
TopologyDescription,
_updated_topology_description_srv_polling,
updated_topology_description,
)
from pymongo.errors import (
ConnectionFailure,
InvalidOperation,
@ -64,12 +43,34 @@ from pymongo.errors import (
ServerSelectionTimeoutError,
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.logger import (
_SERVER_SELECTION_LOGGER,
_debug_log,
_ServerSelectionStatusMessage,
)
from pymongo.pool_options import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import (
Selection,
any_server_selector,
arbiter_server_selector,
secondary_server_selector,
writable_server_selector,
)
from pymongo.topology_description import (
SRV_POLLING_TOPOLOGIES,
TOPOLOGY_TYPE,
TopologyDescription,
_updated_topology_description_srv_polling,
updated_topology_description,
)
if TYPE_CHECKING:
from bson import ObjectId
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.typings import ClusterTime, _Address
from pymongo.typings import ClusterTime, _Address
_IS_SYNC = False
@ -781,8 +782,8 @@ class Topology:
# Default error code if one does not exist.
default = 10107 if isinstance(error, NotPrimaryError) else None
err_code = error.details.get("code", default) # type: ignore[union-attr]
if err_code in helpers_constants._NOT_PRIMARY_CODES:
is_shutting_down = err_code in helpers_constants._SHUTDOWN_CODES
if err_code in helpers_shared._NOT_PRIMARY_CODES:
is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES
# Mark server Unknown, clear the pool, and request check.
if not self._settings.load_balanced:
await self._process_change(ServerDescription(address, error=error))

View File

@ -1,678 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Represent a deployment of MongoDB servers."""
from __future__ import annotations
from random import sample
from typing import (
Any,
Callable,
List,
Mapping,
MutableMapping,
NamedTuple,
Optional,
cast,
)
from bson.min_key import MinKey
from bson.objectid import ObjectId
from pymongo.asynchronous import common
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.server_selectors import Selection
from pymongo.asynchronous.typings import _Address
from pymongo.errors import ConfigurationError
from pymongo.server_type import SERVER_TYPE
_IS_SYNC = False
# Enumeration for various kinds of MongoDB cluster topologies.
class _TopologyType(NamedTuple):
Single: int
ReplicaSetNoPrimary: int
ReplicaSetWithPrimary: int
Sharded: int
Unknown: int
LoadBalanced: int
TOPOLOGY_TYPE = _TopologyType(*range(6))
# Topologies compatible with SRV record polling.
SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
class TopologyDescription:
def __init__(
self,
topology_type: int,
server_descriptions: dict[_Address, ServerDescription],
replica_set_name: Optional[str],
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
topology_settings: Any,
) -> None:
"""Representation of a deployment of MongoDB servers.
:param topology_type: initial type
:param server_descriptions: dict of (address, ServerDescription) for
all seeds
:param replica_set_name: replica set name or None
:param max_set_version: greatest setVersion seen from a primary, or None
:param max_election_id: greatest electionId seen from a primary, or None
:param topology_settings: a TopologySettings
"""
self._topology_type = topology_type
self._replica_set_name = replica_set_name
self._server_descriptions = server_descriptions
self._max_set_version = max_set_version
self._max_election_id = max_election_id
# The heartbeat_frequency is used in staleness estimates.
self._topology_settings = topology_settings
# Is PyMongo compatible with all servers' wire protocols?
self._incompatible_err = None
if self._topology_type != TOPOLOGY_TYPE.LoadBalanced:
self._init_incompatible_err()
# Server Discovery And Monitoring Spec: Whenever a client updates the
# TopologyDescription from an hello response, it MUST set
# TopologyDescription.logicalSessionTimeoutMinutes to the smallest
# logicalSessionTimeoutMinutes value among ServerDescriptions of all
# data-bearing server types. If any have a null
# logicalSessionTimeoutMinutes, then
# TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null.
readable_servers = self.readable_servers
if not readable_servers:
self._ls_timeout_minutes = None
elif any(s.logical_session_timeout_minutes is None for s in readable_servers):
self._ls_timeout_minutes = None
else:
self._ls_timeout_minutes = min( # type: ignore[type-var]
s.logical_session_timeout_minutes for s in readable_servers
)
def _init_incompatible_err(self) -> None:
"""Internal compatibility check for non-load balanced topologies."""
for s in self._server_descriptions.values():
if not s.is_server_type_known:
continue
# s.min/max_wire_version is the server's wire protocol.
# MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports.
server_too_new = (
# Server too new.
s.min_wire_version is not None
and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION
)
server_too_old = (
# Server too old.
s.max_wire_version is not None
and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION
)
if server_too_new:
self._incompatible_err = (
"Server at %s:%d requires wire version %d, but this " # type: ignore
"version of PyMongo only supports up to %d."
% (
s.address[0],
s.address[1] or 0,
s.min_wire_version,
common.MAX_SUPPORTED_WIRE_VERSION,
)
)
elif server_too_old:
self._incompatible_err = (
"Server at %s:%d reports wire version %d, but this " # type: ignore
"version of PyMongo requires at least %d (MongoDB %s)."
% (
s.address[0],
s.address[1] or 0,
s.max_wire_version,
common.MIN_SUPPORTED_WIRE_VERSION,
common.MIN_SUPPORTED_SERVER_VERSION,
)
)
break
def check_compatible(self) -> None:
"""Raise ConfigurationError if any server is incompatible.
A server is incompatible if its wire protocol version range does not
overlap with PyMongo's.
"""
if self._incompatible_err:
raise ConfigurationError(self._incompatible_err)
def has_server(self, address: _Address) -> bool:
return address in self._server_descriptions
def reset_server(self, address: _Address) -> TopologyDescription:
"""A copy of this description, with one server marked Unknown."""
unknown_sd = self._server_descriptions[address].to_unknown()
return updated_topology_description(self, unknown_sd)
def reset(self) -> TopologyDescription:
"""A copy of this description, with all servers marked Unknown."""
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
else:
topology_type = self._topology_type
# The default ServerDescription's type is Unknown.
sds = {address: ServerDescription(address) for address in self._server_descriptions}
return TopologyDescription(
topology_type,
sds,
self._replica_set_name,
self._max_set_version,
self._max_election_id,
self._topology_settings,
)
def server_descriptions(self) -> dict[_Address, ServerDescription]:
"""dict of (address,
:class:`~pymongo.server_description.ServerDescription`).
"""
return self._server_descriptions.copy()
@property
def topology_type(self) -> int:
"""The type of this topology."""
return self._topology_type
@property
def topology_type_name(self) -> str:
"""The topology type as a human readable string.
.. versionadded:: 3.4
"""
return TOPOLOGY_TYPE._fields[self._topology_type]
@property
def replica_set_name(self) -> Optional[str]:
"""The replica set name."""
return self._replica_set_name
@property
def max_set_version(self) -> Optional[int]:
"""Greatest setVersion seen from a primary, or None."""
return self._max_set_version
@property
def max_election_id(self) -> Optional[ObjectId]:
"""Greatest electionId seen from a primary, or None."""
return self._max_election_id
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
"""Minimum logical session timeout, or None."""
return self._ls_timeout_minutes
@property
def known_servers(self) -> list[ServerDescription]:
"""List of Servers of types besides Unknown."""
return [s for s in self._server_descriptions.values() if s.is_server_type_known]
@property
def has_known_servers(self) -> bool:
"""Whether there are any Servers of types besides Unknown."""
return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
@property
def readable_servers(self) -> list[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]
@property
def common_wire_version(self) -> Optional[int]:
"""Minimum of all servers' max wire versions, or None."""
servers = self.known_servers
if servers:
return min(s.max_wire_version for s in self.known_servers)
return None
@property
def heartbeat_frequency(self) -> int:
return self._topology_settings.heartbeat_frequency
@property
def srv_max_hosts(self) -> int:
return self._topology_settings._srv_max_hosts
def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
if not selection:
return []
round_trip_times: list[float] = []
for server in selection.server_descriptions:
if server.round_trip_time is None:
config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}"
raise ConfigurationError(config_err_msg)
round_trip_times.append(server.round_trip_time)
# Round trip time in seconds.
fastest = min(round_trip_times)
threshold = self._topology_settings.local_threshold_ms / 1000.0
return [
s
for s in selection.server_descriptions
if (cast(float, s.round_trip_time) - fastest) <= threshold
]
def apply_selector(
self,
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None,
) -> list[ServerDescription]:
"""List of servers matching the provided selector(s).
:param selector: a callable that takes a Selection as input and returns
a Selection as output. For example, an instance of a read
preference from :mod:`~pymongo.read_preferences`.
:param address: A server address to select.
:param custom_selector: A callable that augments server
selection rules. Accepts a list of
:class:`~pymongo.server_description.ServerDescription` objects and
return a list of server descriptions that should be considered
suitable for the desired operation.
.. versionadded:: 3.4
"""
if getattr(selector, "min_wire_version", 0):
common_wv = self.common_wire_version
if common_wv and common_wv < selector.min_wire_version:
raise ConfigurationError(
"%s requires min wire version %d, but topology's min"
" wire version is %d" % (selector, selector.min_wire_version, common_wv)
)
if isinstance(selector, _AggWritePref):
selector.selection_hook(self)
if self.topology_type == TOPOLOGY_TYPE.Unknown:
return []
elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced):
# Ignore selectors for standalone and load balancer mode.
return self.known_servers
if address:
# Ignore selectors when explicit address is requested.
description = self.server_descriptions().get(address)
return [description] if description else []
selection = Selection.from_topology_description(self)
# Ignore read preference for sharded clusters.
if self.topology_type != TOPOLOGY_TYPE.Sharded:
selection = selector(selection)
# Apply custom selector followed by localThresholdMS.
if custom_selector is not None and selection:
selection = selection.with_server_descriptions(
custom_selector(selection.server_descriptions)
)
return self._apply_local_threshold(selection)
def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool:
"""Does this topology have any readable servers available matching the
given read preference?
:param read_preference: an instance of a read preference from
:mod:`~pymongo.read_preferences`. Defaults to
:attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`.
.. note:: When connected directly to a single server this method
always returns ``True``.
.. versionadded:: 3.4
"""
common.validate_read_preference("read_preference", read_preference)
return any(self.apply_selector(read_preference))
def has_writable_server(self) -> bool:
"""Does this topology have a writable server available?
.. note:: When connected directly to a single server this method
always returns ``True``.
.. versionadded:: 3.4
"""
return self.has_readable_server(ReadPreference.PRIMARY)
def __repr__(self) -> str:
# Sort the servers by address.
servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address)
return "<{} id: {}, topology_type: {}, servers: {!r}>".format(
self.__class__.__name__,
self._topology_settings._topology_id,
self.topology_type_name,
servers,
)
# If topology type is Unknown and we receive a hello response, what should
# the new topology type be?
_SERVER_TYPE_TO_TOPOLOGY_TYPE = {
SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded,
SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary,
SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
# Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out.
}
def updated_topology_description(
topology_description: TopologyDescription, server_description: ServerDescription
) -> TopologyDescription:
"""Return an updated copy of a TopologyDescription.
:param topology_description: the current TopologyDescription
:param server_description: a new ServerDescription that resulted from
a hello call
Called after attempting (successfully or not) to call hello on the
server at server_description.address. Does not modify topology_description.
"""
address = server_description.address
# These values will be updated, if necessary, to form the new
# TopologyDescription.
topology_type = topology_description.topology_type
set_name = topology_description.replica_set_name
max_set_version = topology_description.max_set_version
max_election_id = topology_description.max_election_id
server_type = server_description.server_type
# Don't mutate the original dict of server descriptions; copy it.
sds = topology_description.server_descriptions()
# Replace this server's description with the new one.
sds[address] = server_description
if topology_type == TOPOLOGY_TYPE.Single:
# Set server type to Unknown if replica set name does not match.
if set_name is not None and set_name != server_description.replica_set_name:
error = ConfigurationError(
"client is configured to connect to a replica set named "
"'{}' but this node belongs to a set named '{}'".format(
set_name, server_description.replica_set_name
)
)
sds[address] = server_description.to_unknown(error=error)
# Single type never changes.
return TopologyDescription(
TOPOLOGY_TYPE.Single,
sds,
set_name,
max_set_version,
max_election_id,
topology_description._topology_settings,
)
if topology_type == TOPOLOGY_TYPE.Unknown:
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer):
if len(topology_description._topology_settings.seeds) == 1:
topology_type = TOPOLOGY_TYPE.Single
else:
# Remove standalone from Topology when given multiple seeds.
sds.pop(address)
elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost):
topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type]
if topology_type == TOPOLOGY_TYPE.Sharded:
if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown):
sds.pop(address)
elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary:
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
sds.pop(address)
elif server_type == SERVER_TYPE.RSPrimary:
(topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
sds, set_name, server_description, max_set_version, max_election_id
)
elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
topology_type, set_name = _update_rs_no_primary_from_member(
sds, set_name, server_description
)
elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
sds.pop(address)
topology_type = _check_has_primary(sds)
elif server_type == SERVER_TYPE.RSPrimary:
(topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
sds, set_name, server_description, max_set_version, max_election_id
)
elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description)
else:
# Server type is Unknown or RSGhost: did we just lose the primary?
topology_type = _check_has_primary(sds)
# Return updated copy.
return TopologyDescription(
topology_type,
sds,
set_name,
max_set_version,
max_election_id,
topology_description._topology_settings,
)
def _updated_topology_description_srv_polling(
topology_description: TopologyDescription, seedlist: list[tuple[str, Any]]
) -> TopologyDescription:
"""Return an updated copy of a TopologyDescription.
:param topology_description: the current TopologyDescription
:param seedlist: a list of new seeds new ServerDescription that resulted from
a hello call
"""
assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
# Create a copy of the server descriptions.
sds = topology_description.server_descriptions()
# If seeds haven't changed, don't do anything.
if set(sds.keys()) == set(seedlist):
return topology_description
# Remove SDs corresponding to servers no longer part of the SRV record.
for address in list(sds.keys()):
if address not in seedlist:
sds.pop(address)
if topology_description.srv_max_hosts != 0:
new_hosts = set(seedlist) - set(sds.keys())
n_to_add = topology_description.srv_max_hosts - len(sds)
if n_to_add > 0:
seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
else:
seedlist = []
# Add SDs corresponding to servers recently added to the SRV record.
for address in seedlist:
if address not in sds:
sds[address] = ServerDescription(address)
return TopologyDescription(
topology_description.topology_type,
sds,
topology_description.replica_set_name,
topology_description.max_set_version,
topology_description.max_election_id,
topology_description._topology_settings,
)
def _update_rs_from_primary(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
"""Update topology description from a primary's hello response.
Pass in a dict of ServerDescriptions, current replica set name, the
ServerDescription we are processing, and the TopologyDescription's
max_set_version and max_election_id if any.
Returns (new topology type, new replica_set_name, new max_set_version,
new max_election_id).
"""
if replica_set_name is None:
replica_set_name = server_description.replica_set_name
elif replica_set_name != server_description.replica_set_name:
# We found a primary but it doesn't have the replica_set_name
# provided by the user.
sds.pop(server_description.address)
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
if server_description.max_wire_version is None or server_description.max_wire_version < 17:
new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
max_election_tuple: tuple = (max_set_version, max_election_id)
if None not in new_election_tuple:
if None not in max_election_tuple and new_election_tuple < max_election_tuple:
# Stale primary, set to type Unknown.
sds[server_description.address] = server_description.to_unknown()
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
max_election_id = server_description.election_id
if server_description.set_version is not None and (
max_set_version is None or server_description.set_version > max_set_version
):
max_set_version = server_description.set_version
else:
new_election_tuple = server_description.election_id, server_description.set_version
max_election_tuple = max_election_id, max_set_version
new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple)
max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple)
if new_election_safe < max_election_safe:
# Stale primary, set to type Unknown.
sds[server_description.address] = server_description.to_unknown()
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
else:
max_election_id = server_description.election_id
max_set_version = server_description.set_version
# We've heard from the primary. Is it the same primary as before?
for server in sds.values():
if (
server.server_type is SERVER_TYPE.RSPrimary
and server.address != server_description.address
):
# Reset old primary's type to Unknown.
sds[server.address] = server.to_unknown()
# There can be only one prior primary.
break
# Discover new hosts from this primary's response.
for new_address in server_description.all_hosts:
if new_address not in sds:
sds[new_address] = ServerDescription(new_address)
# Remove hosts not in the response.
for addr in set(sds) - server_description.all_hosts:
sds.pop(addr)
# If the host list differs from the seed list, we may not have a primary
# after all.
return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
def _update_rs_with_primary_from_member(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
) -> int:
"""RS with known primary. Process a response from a non-primary.
Pass in a dict of ServerDescriptions, current replica set name, and the
ServerDescription we are processing.
Returns new topology type.
"""
assert replica_set_name is not None
if replica_set_name != server_description.replica_set_name:
sds.pop(server_description.address)
elif server_description.me and server_description.address != server_description.me:
sds.pop(server_description.address)
# Had this member been the primary?
return _check_has_primary(sds)
def _update_rs_no_primary_from_member(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
) -> tuple[int, Optional[str]]:
"""RS without known primary. Update from a non-primary's response.
Pass in a dict of ServerDescriptions, current replica set name, and the
ServerDescription we are processing.
Returns (new topology type, new replica_set_name).
"""
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
if replica_set_name is None:
replica_set_name = server_description.replica_set_name
elif replica_set_name != server_description.replica_set_name:
sds.pop(server_description.address)
return topology_type, replica_set_name
# This isn't the primary's response, so don't remove any servers
# it doesn't report. Only add new servers.
for address in server_description.all_hosts:
if address not in sds:
sds[address] = ServerDescription(address)
if server_description.me and server_description.address != server_description.me:
sds.pop(server_description.address)
return topology_type, replica_set_name
def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int:
"""Current topology type is ReplicaSetWithPrimary. Is primary still known?
Pass in a dict of ServerDescriptions.
Returns new topology type.
"""
for s in sds.values():
if s.server_type == SERVER_TYPE.RSPrimary:
return TOPOLOGY_TYPE.ReplicaSetWithPrimary
else: # noqa: PLW0120
return TOPOLOGY_TYPE.ReplicaSetNoPrimary

View File

@ -1,61 +0,0 @@
# Copyright 2022-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Type aliases used by PyMongo"""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg
if TYPE_CHECKING:
from pymongo.asynchronous.collation import Collation
_IS_SYNC = False
# Common Shared Types.
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_Pipeline = Sequence[Mapping[str, Any]]
ClusterTime = Mapping[str, Any]
_T = TypeVar("_T")
def strip_optional(elem: Optional[_T]) -> _T:
"""This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T
while inside a list comprehension.
"""
assert elem is not None
return elem
__all__ = [
"_DocumentOut",
"_DocumentType",
"_DocumentTypeArg",
"_Address",
"_CollationIn",
"_Pipeline",
"strip_optional",
]

View File

@ -1,624 +0,0 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse and validate a MongoDB URI."""
from __future__ import annotations
import re
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Mapping,
MutableMapping,
Optional,
Sized,
Union,
cast,
)
from urllib.parse import unquote_plus
from pymongo.asynchronous.client_options import _parse_ssl_options
from pymongo.asynchronous.common import (
INTERNAL_URI_OPTION_NAME_MAP,
SRV_SERVICE_NAME,
URI_OPTIONS_DEPRECATION_MAP,
_CaseInsensitiveDictionary,
get_validated_options,
)
from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver
from pymongo.asynchronous.typings import _Address
from pymongo.errors import ConfigurationError, InvalidURI
if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
_IS_SYNC = False
SCHEME = "mongodb://"
SCHEME_LEN = len(SCHEME)
SRV_SCHEME = "mongodb+srv://"
SRV_SCHEME_LEN = len(SRV_SCHEME)
DEFAULT_PORT = 27017
def _unquoted_percent(s: str) -> bool:
"""Check for unescaped percent signs.
:param s: A string. `s` can have things like '%25', '%2525',
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
"""
for i in range(len(s)):
if s[i] == "%":
sub = s[i : i + 3]
# If unquoting yields the same string this means there was an
# unquoted %.
if unquote_plus(sub) == sub:
return True
return False
def parse_userinfo(userinfo: str) -> tuple[str, str]:
"""Validates the format of user information in a MongoDB URI.
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
"]", "@") as per RFC 3986 must be escaped.
Returns a 2-tuple containing the unescaped username followed
by the unescaped password.
:param userinfo: A string of the form <username>:<password>
"""
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
raise InvalidURI(
"Username and password must be escaped according to "
"RFC 3986, use urllib.parse.quote_plus"
)
user, _, passwd = userinfo.partition(":")
# No password is expected with GSSAPI authentication.
if not user:
raise InvalidURI("The empty string is not valid username.")
return unquote_plus(user), unquote_plus(passwd)
def parse_ipv6_literal_host(
entity: str, default_port: Optional[int]
) -> tuple[str, Optional[Union[str, int]]]:
"""Validates an IPv6 literal host:port string.
Returns a 2-tuple of IPv6 literal followed by port where
port is default_port if it wasn't specified in entity.
:param entity: A string that represents an IPv6 literal enclosed
in braces (e.g. '[::1]' or '[::1]:27017').
:param default_port: The port number to use when one wasn't
specified in entity.
"""
if entity.find("]") == -1:
raise ValueError(
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
)
i = entity.find("]:")
if i == -1:
return entity[1:-1], default_port
return entity[1:i], entity[i + 2 :]
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
"""Validates a host string
Returns a 2-tuple of host followed by port where port is default_port
if it wasn't specified in the string.
:param entity: A host or host:port string where host could be a
hostname or IP address.
:param default_port: The port number to use when one wasn't
specified in entity.
"""
host = entity
port: Optional[Union[str, int]] = default_port
if entity[0] == "[":
host, port = parse_ipv6_literal_host(entity, default_port)
elif entity.endswith(".sock"):
return entity, default_port
elif entity.find(":") != -1:
if entity.count(":") > 1:
raise ValueError(
"Reserved characters such as ':' must be "
"escaped according RFC 2396. An IPv6 "
"address literal must be enclosed in '[' "
"and ']' according to RFC 2732."
)
host, port = host.split(":", 1)
if isinstance(port, str):
if not port.isdigit() or int(port) > 65535 or int(port) <= 0:
raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}")
port = int(port)
# Normalize hostname to lowercase, since DNS is case-insensitive:
# http://tools.ietf.org/html/rfc4343
# This prevents useless rediscovery if "foo.com" is in the seed list but
# "FOO.com" is in the hello response.
return host.lower(), port
# Options whose values are implicitly determined by tlsInsecure.
_IMPLICIT_TLSINSECURE_OPTS = {
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
}
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
"""Helper method for split_options which creates the options dict.
Also handles the creation of a list for the URI tag_sets/
readpreferencetags portion, and the use of a unicode options string.
"""
options = _CaseInsensitiveDictionary()
for uriopt in opts.split(delim):
key, value = uriopt.split("=")
if key.lower() == "readpreferencetags":
options.setdefault(key, []).append(value)
else:
if key in options:
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
if key.lower() == "authmechanismproperties":
val = value
else:
val = unquote_plus(value)
options[key] = val
return options
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Raise appropriate errors when conflicting TLS options are present in
the options dictionary.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Implicitly defined options must not be explicitly specified.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
if opt in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
)
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
if tlsallowinvalidcerts is not None:
if "tlsdisableocspendpointcheck" in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
)
if tlsallowinvalidcerts is True:
options["tlsdisableocspendpointcheck"] = True
# Handle co-occurence of CRL and OCSP-related options.
tlscrlfile = options.get("tlscrlfile")
if tlscrlfile is not None:
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
if options.get(opt) is True:
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
raise InvalidURI(err_msg % (opt,))
if "ssl" in options and "tls" in options:
def truth_value(val: Any) -> Any:
if val in ("true", "false"):
return val == "true"
if isinstance(val, bool):
return val
return val
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
err_msg = "Can not specify conflicting values for URI options %s and %s."
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
return options
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Issue appropriate warnings when deprecated options are present in the
options dictionary. Removes deprecated option key, value pairs if the
options dictionary is found to also have the renamed option.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
for optname in list(options):
if optname in URI_OPTIONS_DEPRECATION_MAP:
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
if mode == "renamed":
newoptname = message
if newoptname in options:
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
warnings.warn(
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
DeprecationWarning,
stacklevel=2,
)
options.pop(optname)
continue
warn_msg = "Option '%s' is deprecated, use '%s' instead."
warnings.warn(
warn_msg % (options.cased_key(optname), newoptname),
DeprecationWarning,
stacklevel=2,
)
elif mode == "removed":
warn_msg = "Option '%s' is deprecated. %s."
warnings.warn(
warn_msg % (options.cased_key(optname), message),
DeprecationWarning,
stacklevel=2,
)
return options
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Normalizes option names in the options dictionary by converting them to
their internally-used names.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Expand the tlsInsecure option.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
# Implicit options are logically the same as tlsInsecure.
options[opt] = tlsinsecure
for optname in list(options):
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
if intname is not None:
options[intname] = options.pop(optname)
return options
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
"""Validates and normalizes options passed in a MongoDB URI.
Returns a new dictionary of validated and normalized options. If warn is
False then errors will be thrown for invalid options, otherwise they will
be ignored and a warning will be issued.
:param opts: A dict of MongoDB URI options.
:param warn: If ``True`` then warnings will be logged and
invalid options will be ignored. Otherwise invalid options will
cause errors.
"""
return get_validated_options(opts, warn)
def split_options(
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
) -> MutableMapping[str, Any]:
"""Takes the options portion of a MongoDB URI, validates each option
and returns the options in a dictionary.
:param opt: A string representing MongoDB URI options.
:param validate: If ``True`` (the default), validate and normalize all
options.
:param warn: If ``False`` (the default), suppress all warnings raised
during validation of options.
:param normalize: If ``True`` (the default), renames all options to their
internally-used names.
"""
and_idx = opts.find("&")
semi_idx = opts.find(";")
try:
if and_idx >= 0 and semi_idx >= 0:
raise InvalidURI("Can not mix '&' and ';' for option separators.")
elif and_idx >= 0:
options = _parse_options(opts, "&")
elif semi_idx >= 0:
options = _parse_options(opts, ";")
elif opts.find("=") != -1:
options = _parse_options(opts, None)
else:
raise ValueError
except ValueError:
raise InvalidURI("MongoDB URI options are key=value pairs.") from None
options = _handle_security_options(options)
options = _handle_option_deprecations(options)
if normalize:
options = _normalize_options(options)
if validate:
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
if options.get("authsource") == "":
raise InvalidURI("the authSource database cannot be an empty string")
return options
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
"""Takes a string of the form host1[:port],host2[:port]... and
splits it into (host, port) tuples. If [:port] isn't present the
default_port is used.
Returns a set of 2-tuples containing the host name (or IP) followed by
port number.
:param hosts: A string of the form host1[:port],host2[:port],...
:param default_port: The port number to use when one wasn't specified
for a host.
"""
nodes = []
for entity in hosts.split(","):
if not entity:
raise ConfigurationError("Empty host (or extra comma in host list).")
port = default_port
# Unix socket entities don't have ports
if entity.endswith(".sock"):
port = None
nodes.append(parse_host(entity, port))
return nodes
# Prohibited characters in database name. DB names also can't have ".", but for
# backward-compat we allow "db.collection" in URI.
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
_ALLOWED_TXT_OPTS = frozenset(
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
)
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
# Ensure directConnection was not True if there are multiple seeds.
if len(nodes) > 1 and options.get("directconnection"):
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
if options.get("loadbalanced"):
if len(nodes) > 1:
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
if options.get("directconnection"):
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
if options.get("replicaset"):
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
def parse_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
{
'nodelist': <list of (host, port) tuples>,
'username': <username> or None,
'password': <password> or None,
'database': <database name> or None,
'collection': <collection name> or None,
'options': <dict of MongoDB URI options>,
'fqdn': <fqdn of the MongoDB+SRV URI> or None
}
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
to build nodelist and options.
:param uri: The MongoDB URI to parse.
:param default_port: The port number to use when one wasn't specified
for a host in the URI.
:param validate: If ``True`` (the default), validate and
normalize all options. Default: ``True``.
:param warn: When validating, if ``True`` then will warn
the user then ignore any invalid options or values. If ``False``,
validation will error when options are unsupported or values are
invalid. Default: ``False``.
:param normalize: If ``True``, convert names of URI options
to their internally-used names. Default: ``True``.
:param connect_timeout: The maximum time in milliseconds to
wait for a response from the DNS server.
:param srv_service_name: A custom SRV service name
.. versionchanged:: 4.6
The delimiting slash (``/``) between hosts and connection options is now optional.
For example, "mongodb://example.com?tls=true" is now a valid URI.
.. versionchanged:: 4.0
To better follow RFC 3986, unquoted percent signs ("%") are no longer
supported.
.. versionchanged:: 3.9
Added the ``normalize`` parameter.
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
.. versionchanged:: 3.5
Return the original value of the ``readPreference`` MongoDB URI option
instead of the validated read preference mode.
.. versionchanged:: 3.1
``warn`` added so invalid options can be ignored.
"""
if uri.startswith(SCHEME):
is_srv = False
scheme_free = uri[SCHEME_LEN:]
elif uri.startswith(SRV_SCHEME):
if not _have_dnspython():
python_path = sys.executable or "python"
raise ConfigurationError(
'The "dnspython" module must be '
"installed to use mongodb+srv:// URIs. "
"To fix this error install pymongo again:\n "
"%s -m pip install pymongo>=4.3" % (python_path)
)
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
else:
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
if not scheme_free:
raise InvalidURI("Must provide at least one hostname or IP.")
user = None
passwd = None
dbase = None
collection = None
options = _CaseInsensitiveDictionary()
host_plus_db_part, _, opts = scheme_free.partition("?")
if "/" in host_plus_db_part:
host_part, _, dbase = host_plus_db_part.partition("/")
else:
host_part = host_plus_db_part
if dbase:
dbase = unquote_plus(dbase)
if "." in dbase:
dbase, collection = dbase.split(".", 1)
if _BAD_DB_CHARS.search(dbase):
raise InvalidURI('Bad database name "%s"' % dbase)
else:
dbase = None
if opts:
options.update(split_options(opts, validate, warn, normalize))
if srv_service_name is None:
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
if "@" in host_part:
userinfo, _, hosts = host_part.rpartition("@")
user, passwd = parse_userinfo(userinfo)
else:
hosts = host_part
if "/" in hosts:
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
hosts = unquote_plus(hosts)
fqdn = None
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
if options.get("directConnection"):
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
nodes = split_hosts(hosts, default_port=None)
if len(nodes) != 1:
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
fqdn, port = nodes[0]
if port is not None:
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
nodes = dns_resolver.get_hosts()
dns_options = dns_resolver.get_options()
if dns_options:
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
raise ConfigurationError(
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
)
for opt, val in parsed_dns_options.items():
if opt not in options:
options[opt] = val
if options.get("loadBalanced") and srv_max_hosts:
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
if options.get("replicaSet") and srv_max_hosts:
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
if "tls" not in options and "ssl" not in options:
options["tls"] = True if validate else "true"
elif not is_srv and options.get("srvServiceName") is not None:
raise ConfigurationError(
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
)
elif not is_srv and srv_max_hosts:
raise ConfigurationError(
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
)
else:
nodes = split_hosts(hosts, default_port=default_port)
_check_options(nodes, options)
return {
"nodelist": nodes,
"username": user,
"password": passwd,
"database": dbase,
"collection": collection,
"options": options,
"fqdn": fqdn,
}
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
"""Parse KMS TLS connection options."""
if not kms_tls_options:
return {}
if not isinstance(kms_tls_options, dict):
raise TypeError("kms_tls_options must be a dict")
contexts = {}
for provider, options in kms_tls_options.items():
if not isinstance(options, dict):
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
options.setdefault("tls", True)
opts = _CaseInsensitiveDictionary(options)
opts = _handle_security_options(opts)
opts = _normalize_options(opts)
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
if ssl_context is None:
raise ConfigurationError("TLS is required for KMS providers")
if allow_invalid_hostnames:
raise ConfigurationError("Insecure TLS options prohibited")
for n in [
"tlsInsecure",
"tlsAllowInvalidCertificates",
"tlsAllowInvalidHostnames",
"tlsDisableCertificateRevocationCheck",
]:
if n in opts:
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
contexts[provider] = ssl_context
return contexts
if __name__ == "__main__":
import pprint
try:
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
except InvalidURI as exc:
print(exc) # noqa: T201
sys.exit(0)

View File

@ -15,6 +15,7 @@
"""Re-import of synchronous Auth API for compatibility."""
from __future__ import annotations
from pymongo.auth_shared import * # noqa: F403
from pymongo.synchronous.auth import * # noqa: F403
from pymongo.synchronous.auth import __doc__ as original_doc

View File

@ -15,7 +15,9 @@
"""Re-import of synchronous AuthOIDC API for compatibility."""
from __future__ import annotations
from pymongo.auth_oidc_shared import * # noqa: F403
from pymongo.synchronous.auth_oidc import * # noqa: F403
from pymongo.synchronous.auth_oidc import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["OIDCCallback", "OIDCCallbackContext", "OIDCCallbackResult", "OIDCIdPInfo"] # noqa: F405

118
pymongo/auth_oidc_shared.py Normal file
View File

@ -0,0 +1,118 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Constants, types, and classes shared across OIDC auth implementations."""
from __future__ import annotations
import abc
import os
from dataclasses import dataclass, field
from typing import Optional
from urllib.parse import quote
from pymongo._azure_helpers import _get_azure_response
from pymongo._gcp_helpers import _get_gcp_response
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@dataclass
class OIDCCallbackContext:
timeout_seconds: float
username: str
version: int
refresh_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
@dataclass
class OIDCCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)
class OIDCCallback(abc.ABC):
"""A base class for defining OIDC callbacks."""
@abc.abstractmethod
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
"""Convert the given BSON value into our own type."""
@dataclass
class _OIDCProperties:
callback: Optional[OIDCCallback] = field(default=None)
human_callback: Optional[OIDCCallback] = field(default=None)
environment: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)
token_resource: Optional[str] = field(default=None)
username: str = ""
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
class _OIDCTestCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("OIDC_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAWSCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAzureCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
return OIDCCallbackResult(
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])

236
pymongo/auth_shared.py Normal file
View File

@ -0,0 +1,236 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Constants and types shared across multiple auth types."""
from __future__ import annotations
import os
import typing
from base64 import standard_b64encode
from collections import namedtuple
from typing import Any, Dict, Mapping, Optional
from bson import Binary
from pymongo.auth_oidc_shared import (
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
from pymongo.errors import ConfigurationError
MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DEFAULT",
]
)
"""The authentication mechanisms supported by PyMongo."""
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
def __init__(self) -> None:
self.data = None
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self) -> int:
return self._hash_val
MongoCredential = namedtuple(
"MongoCredential",
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
)
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple(
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
)
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
properties = extra.get("authmechanismproperties", {})
service_name = properties.get("SERVICE_NAME", "mongodb")
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
service_realm = properties.get("SERVICE_REALM")
props = GSSAPIProperties(
service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm,
)
# Source is always $external.
return MongoCredential(mech, "$external", user, passwd, props, None)
elif mech == "MONGODB-X509":
if passwd is not None:
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for MONGODB-X509")
# Source is always $external, user can be None.
return MongoCredential(mech, "$external", user, None, None, None)
elif mech == "MONGODB-AWS":
if user is not None and passwd is None:
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
if source is not None and source != "$external":
raise ConfigurationError(
"authentication source must be $external or None for MONGODB-AWS"
)
properties = extra.get("authmechanismproperties", {})
aws_session_token = properties.get("AWS_SESSION_TOKEN")
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
callback = properties.get("OIDC_CALLBACK")
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
environ = properties.get("ENVIRONMENT")
token_resource = properties.get("TOKEN_RESOURCE", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
if passwd is not None:
msg = "password is not supported by MONGODB-OIDC"
raise ConfigurationError(msg)
if callback or human_callback:
if environ is not None:
raise ConfigurationError(msg)
if callback and human_callback:
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
raise ConfigurationError(msg)
elif environ is not None:
if environ == "test":
if user is not None:
msg = "test environment for MONGODB-OIDC does not support username"
raise ConfigurationError(msg)
callback = _OIDCTestCallback()
elif environ == "azure":
passwd = None
if not token_resource:
raise ConfigurationError(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:
raise ConfigurationError(msg)
oidc_props = _OIDCProperties(
callback=callback,
human_callback=human_callback,
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
"""Split a scram response into key, value pairs."""
return dict(
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
for item in response.split(b",")
)
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, typing.MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = {
"saslStart": 1,
"mechanism": mechanism,
"payload": Binary(b"n,," + first_bare),
"autoAuthorize": 1,
"options": {"skipEmptyExchange": True},
}
return nonce, first_bare, cmd

131
pymongo/bulk_shared.py Normal file
View File

@ -0,0 +1,131 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Constants, types, and classes shared across Bulk Write API implementations."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, NoReturn
from pymongo.errors import BulkWriteError, OperationFailure
from pymongo.helpers_shared import _get_wce_doc
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
)
if TYPE_CHECKING:
from pymongo.typings import _DocumentOut
_DELETE_ALL: int = 0
_DELETE_ONE: int = 1
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
_BAD_VALUE: int = 2
_UNKNOWN_ERROR: int = 8
_WRITE_CONCERN_ERROR: int = 64
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
class _Run:
"""Represents a batch of write operations."""
def __init__(self, op_type: int) -> None:
"""Initialize a new Run object."""
self.op_type: int = op_type
self.index_map: list[int] = []
self.ops: list[Any] = []
self.idx_offset: int = 0
def index(self, idx: int) -> int:
"""Get the original index of an operation in this run.
:param idx: The Run index that maps to the original index.
"""
return self.index_map[idx]
def add(self, original_index: int, operation: Any) -> None:
"""Add an operation to this Run instance.
:param original_index: The original index of this operation
within a larger bulk operation.
:param operation: The operation document.
"""
self.index_map.append(original_index)
self.ops.append(operation)
def _merge_command(
run: _Run,
full_result: MutableMapping[str, Any],
offset: int,
result: Mapping[str, Any],
) -> None:
"""Merge a write command result into the full bulk result."""
affected = result.get("n", 0)
if run.op_type == _INSERT:
full_result["nInserted"] += affected
elif run.op_type == _DELETE:
full_result["nRemoved"] += affected
elif run.op_type == _UPDATE:
upserted = result.get("upserted")
if upserted:
n_upserted = len(upserted)
for doc in upserted:
doc["index"] = run.index(doc["index"] + offset)
full_result["upserted"].extend(upserted)
full_result["nUpserted"] += n_upserted
full_result["nMatched"] += affected - n_upserted
else:
full_result["nMatched"] += affected
full_result["nModified"] += result["nModified"]
write_errors = result.get("writeErrors")
if write_errors:
for doc in write_errors:
# Leave the server response intact for APM.
replacement = doc.copy()
idx = doc["index"] + offset
replacement["index"] = run.index(idx)
# Add the failed operation to the error document.
replacement["op"] = run.ops[idx]
full_result["writeErrors"].append(replacement)
wce = _get_wce_doc(result)
if wce:
full_result["writeConcernErrors"].append(wce)
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
"""Raise a BulkWriteError from the full bulk api result."""
# retryWrites on MMAPv1 should raise an actionable error.
if full_result["writeErrors"]:
full_result["writeErrors"].sort(key=lambda error: error["index"])
err = full_result["writeErrors"][0]
code = err["code"]
msg = err["errmsg"]
if code == 20 and msg.startswith("Transaction numbers"):
errmsg = (
"This MongoDB deployment does not support "
"retryable writes. Please add retryWrites=false "
"to your connection string."
)
raise OperationFailure(errmsg, code, full_result)
raise BulkWriteError(full_result)

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.change_stream import * # noqa: F403
from pymongo.synchronous.change_stream import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["ChangeStream", "ClusterChangeStream", "CollectionChangeStream", "DatabaseChangeStream"] # noqa: F405

View File

@ -1,21 +1,332 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Re-import of synchronous ClientOptions API for compatibility."""
"""Tools to parse mongo client options."""
from __future__ import annotations
from pymongo.synchronous.client_options import * # noqa: F403
from pymongo.synchronous.client_options import __doc__ as original_doc
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
__doc__ = original_doc
from bson.codec_options import _parse_codec_options
from pymongo import common
from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListener, _EventListeners
from pymongo.pool_options import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import (
_ServerMode,
make_read_preference,
read_pref_mode_from_name,
)
from pymongo.server_selectors import any_server_selector
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern, validate_boolean
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.auth_shared import MongoCredential
from pymongo.encryption_options import AutoEncryptionOpts
from pymongo.pyopenssl_context import SSLContext
from pymongo.topology_description import _ServerSelector
def _parse_credentials(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
"""Parse authentication credentials."""
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
source = options.get("authsource")
if username or mechanism:
from pymongo.auth_shared import _build_credentials_tuple
return _build_credentials_tuple(mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
"""Parse read preference options."""
if "read_preference" in options:
return options["read_preference"]
name = options.get("readpreference", "primary")
mode = read_pref_mode_from_name(name)
tags = options.get("readpreferencetags")
max_staleness = options.get("maxstalenessseconds", -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
"""Parse write concern options."""
concern = options.get("w")
wtimeout = options.get("wtimeoutms")
j = options.get("journal")
fsync = options.get("fsync")
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
"""Parse read concern options."""
concern = options.get("readconcernlevel")
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
validate_boolean("tls", use_tls)
certfile = options.get("tlscertificatekeyfile")
passphrase = options.get("tlscertificatekeyfilepassword")
ca_certs = options.get("tlscafile")
crlfile = options.get("tlscrlfile")
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
enabled_tls_opts = []
for opt in (
"tlscertificatekeyfile",
"tlscertificatekeyfilepassword",
"tlscafile",
"tlscrlfile",
):
# Any non-null value of these options implies tls=True.
if opt in options and options[opt]:
enabled_tls_opts.append(opt)
for opt in (
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
):
# A value of False for these options implies tls=True.
if opt in options and not options[opt]:
enabled_tls_opts.append(opt)
if enabled_tls_opts:
if use_tls is None:
# Implicitly enable TLS when one of the tls* options is set.
use_tls = True
elif not use_tls:
# Error since tls is explicitly disabled but a tls option is set.
raise ConfigurationError(
"TLS has not been enabled but the "
"following tls parameters have been set: "
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
)
if use_tls:
ctx = get_ssl_context(
certfile,
passphrase,
ca_certs,
crlfile,
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
def _parse_pool_options(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(
max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout,
socket_timeout,
wait_queue_timeout,
ssl_context,
tls_allow_invalid_hostnames,
_EventListeners(event_listeners),
appname,
driver,
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced,
credentials=credentials,
)
class ClientOptions:
"""Read only configuration options for an AsyncMongoClient/MongoClient.
Should not be instantiated directly by application developers. Access
a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options` or :attr:`pymongo.mongo_client.MongoClient.options`
instead.
"""
def __init__(
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__direct_connection = options.get("directconnection")
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
)
self.__pool_options = _parse_pool_options(username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get("replicaset")
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get("connect")
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
self.__server_selector = options.get("server_selector", any_server_selector)
self.__auto_encryption_opts = options.get("auto_encryption_opts")
self.__load_balanced = options.get("loadbalanced")
self.__timeout = options.get("timeoutms")
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
@property
def _options(self) -> Mapping[str, Any]:
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self) -> Optional[bool]:
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self) -> CodecOptions:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def direct_connection(self) -> Optional[bool]:
"""Whether to connect to the deployment in 'Single' topology."""
return self.__direct_connection
@property
def local_threshold_ms(self) -> int:
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self) -> int:
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self) -> _ServerSelector:
return self.__server_selector
@property
def heartbeat_frequency(self) -> int:
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self) -> PoolOptions:
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self) -> _ServerMode:
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self) -> WriteConcern:
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self) -> ReadConcern:
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def timeout(self) -> Optional[float]:
"""The configured timeoutMS converted to seconds, or None.
.. versionadded:: 4.2
"""
return self.__timeout
@property
def retry_writes(self) -> bool:
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self) -> bool:
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self.__load_balanced
@property
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()
@property
def server_monitoring_mode(self) -> str:
"""The configured serverMonitoringMode option.
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.client_session import * # noqa: F403
from pymongo.synchronous.client_session import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["ClientSession", "SessionOptions", "TransactionOptions"] # noqa: F405

View File

@ -1,4 +1,4 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2016-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,10 +12,213 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-import of synchronous Collation API for compatibility."""
"""Tools for working with `collations`_.
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
from __future__ import annotations
from pymongo.synchronous.collation import * # noqa: F403
from pymongo.synchronous.collation import __doc__ as original_doc
from typing import Any, Mapping, Optional, Union
__doc__ = original_doc
from pymongo import common
from pymongo.write_concern import validate_boolean
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = "non-ignorable"
"""Spaces and punctuation are treated as base characters."""
SHIFTED = "shifted"
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished regardless when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = "punct"
"""Both punctuation and spaces are ignored."""
SPACE = "space"
"""Spaces alone are ignored."""
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = "upper"
"""Sort uppercase characters first."""
LOWER = "lower"
"""Sort lowercase characters first."""
OFF = "off"
"""Default for locale or collation strength."""
class Collation:
"""Collation
:param locale: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
:param caseFirst: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
:param strength: Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
:param numericOrdering: If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
:param alternate: Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
:param maxVariable: When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
:param normalization: If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
:param backwards: If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
:param kwargs: Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded: 3.4
"""
__slots__ = ("__document",)
def __init__(
self,
locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
if strength is not None:
self.__document["strength"] = common.validate_integer("strength", strength)
if numericOrdering is not None:
self.__document["numericOrdering"] = validate_boolean(
"numericOrdering", numericOrdering
)
if alternate is not None:
self.__document["alternate"] = common.validate_string("alternate", alternate)
if maxVariable is not None:
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
if normalization is not None:
self.__document["normalization"] = validate_boolean("normalization", normalization)
if backwards is not None:
self.__document["backwards"] = validate_boolean("backwards", backwards)
self.__document.update(kwargs)
@property
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self) -> str:
document = self.document
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")

View File

@ -19,3 +19,7 @@ from pymongo.synchronous.collection import * # noqa: F403
from pymongo.synchronous.collection import __doc__ as original_doc
__doc__ = original_doc
__all__ = [ # noqa: F405
"Collection",
"ReturnDocument",
]

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.command_cursor import * # noqa: F403
from pymongo.synchronous.command_cursor import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["CommandCursor", "RawBatchCommandCursor"] # noqa: F405

View File

@ -40,22 +40,21 @@ from bson import SON
from bson.binary import UuidRepresentation
from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry
from bson.raw_bson import RawBSONDocument
from pymongo.driver_info import DriverInfo
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.server_api import ServerApi
from pymongo.synchronous.compression_support import (
from pymongo.compression_support import (
validate_compressors,
validate_zlib_compression_level,
)
from pymongo.synchronous.monitoring import _validate_event_listeners
from pymongo.synchronous.read_preferences import _MONGOS_MODES, _ServerMode
from pymongo.driver_info import DriverInfo
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _validate_event_listeners
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
from pymongo.server_api import ServerApi
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
if TYPE_CHECKING:
from pymongo.synchronous.client_session import ClientSession
from pymongo.typings import _AgnosticClientSession
_IS_SYNC = True
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
@ -69,7 +68,8 @@ MAX_WRITE_BATCH_SIZE = 1000
# What this version of PyMongo supports.
MIN_SUPPORTED_SERVER_VERSION = "3.6"
MIN_SUPPORTED_WIRE_VERSION = 6
MAX_SUPPORTED_WIRE_VERSION = 21
# MongoDB 8.0
MAX_SUPPORTED_WIRE_VERSION = 25
# Frequency to call hello on servers, in seconds.
HEARTBEAT_FREQUENCY = 10
@ -380,7 +380,7 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
def validate_auth_mechanism(option: str, value: Any) -> str:
"""Validate the authMechanism URI option."""
from pymongo.synchronous.auth import MECHANISMS
from pymongo.auth_shared import MECHANISMS
if value not in MECHANISMS:
raise ValueError(f"{option} must be in {tuple(MECHANISMS)}")
@ -446,7 +446,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
elif key in ["ALLOWED_HOSTS"] and isinstance(value, list):
props[key] = value
elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]:
from pymongo.synchronous.auth_oidc import OIDCCallback
from pymongo.auth_oidc_shared import OIDCCallback
if not isinstance(value, OIDCCallback):
raise ValueError("callback must be an OIDCCallback object")
@ -642,7 +642,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
"""Validate the driver keyword arg."""
if value is None:
return value
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
from pymongo.encryption_options import AutoEncryptionOpts
if not isinstance(value, AutoEncryptionOpts):
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
@ -941,7 +941,7 @@ class BaseObject:
"""
return self._write_concern
def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern:
def _write_concern_for(self, session: Optional[_AgnosticClientSession]) -> WriteConcern:
"""Read only access to the write concern of this instance or session."""
# Override this operation's write concern with the transaction's.
if session and session.in_transaction:
@ -957,7 +957,7 @@ class BaseObject:
"""
return self._read_preference
def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode:
def _read_preference_for(self, session: Optional[_AgnosticClientSession]) -> _ServerMode:
"""Read only access to the read preference of this instance or session."""
# Override this operation's read preference with the transaction's.
if session:

View File

@ -16,11 +16,8 @@ from __future__ import annotations
import warnings
from typing import Any, Iterable, Optional, Union
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
from pymongo.synchronous.hello_compat import HelloCompat
_IS_SYNC = True
from pymongo.hello import HelloCompat
from pymongo.helpers_shared import _SENSITIVE_COMMANDS
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}

View File

@ -20,3 +20,4 @@ from pymongo.synchronous.cursor import * # noqa: F403
from pymongo.synchronous.cursor import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["Cursor", "CursorType", "RawBatchCursor"] # noqa: F405

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.database import * # noqa: F403
from pymongo.synchronous.database import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["Database"] # noqa: F405

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.encryption import * # noqa: F403
from pymongo.synchronous.encryption import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["Algorithm", "ClientEncryption", "QueryType", "RewrapManyDataKeyResult"] # noqa: F405

View File

@ -1,4 +1,4 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,10 +12,261 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-import of synchronous EncryptionOptions API for compatibility."""
"""Support for automatic client-side field level encryption."""
from __future__ import annotations
from pymongo.synchronous.encryption_options import * # noqa: F403
from pymongo.synchronous.encryption_options import __doc__ as original_doc
from typing import TYPE_CHECKING, Any, Mapping, Optional
__doc__ = original_doc
try:
import pymongocrypt # type:ignore[import] # noqa: F401
# Check for pymongocrypt>=1.10.
from pymongocrypt import synchronous as _ # noqa: F401
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.common import validate_is_mapping
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import _parse_kms_tls_options
if TYPE_CHECKING:
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
def __init__(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: bool = False,
mongocryptd_uri: str = "mongodb://localhost:27020",
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = "mongocryptd",
mongocryptd_spawn_args: Optional[list[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None,
crypt_shared_lib_path: Optional[str] = None,
crypt_shared_lib_required: bool = False,
bypass_query_analysis: bool = False,
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
) -> None:
"""Options to configure automatic client-side field level encryption.
Automatic client-side field level encryption requires MongoDB >=4.2
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
supported for operations on a database or view and will result in
error.
Although automatic encryption requires MongoDB >=4.2 enterprise or a
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
users. To configure automatic *decryption* without automatic
*encryption* set ``bypass_auto_encryption=True``. Explicit
encryption and explicit decryption is also supported for all users
with the :class:`~pymongo.encryption.ClientEncryption` class.
See :ref:`automatic-client-side-encryption` for an example.
:param kms_providers: Map of KMS provider options. The `kms_providers`
map values differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages. An optional "sessionToken" may be
included to support temporary AWS credentials.
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
strings. Additionally, "identityPlatformEndpoint" may also be
specified as a string (defaults to 'login.microsoftonline.com').
These are the Azure Active Directory credentials used to
generate Azure Key Vault messages.
- `gcp`: Map with "email" as a string and "privateKey"
as `bytes` or a base64 encoded string.
Additionally, "endpoint" may also be specified as a string
(defaults to 'oauth2.googleapis.com'). These are the
credentials used to generate Google Cloud KMS messages.
- `kmip`: Map with "endpoint" as a host with required port.
For example: ``{"endpoint": "example.com:443"}``.
- `local`: Map with "key" as `bytes` (96 bytes in length) or
a base64 encoded string which decodes
to 96 bytes. "key" is the master key used to encrypt/decrypt
data keys. This key should be generated and stored as securely
as possible.
KMS providers may be specified with an optional name suffix
separated by a colon, for example "kmip:name" or "aws:name".
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
Named KMS providers enables more than one of each KMS provider type to be configured.
For example, to configure multiple local KMS providers::
kms_providers = {
"local": {"key": local_kek1}, # Unnamed KMS provider.
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
}
:param key_vault_namespace: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
:param key_vault_client: By default, the key vault collection
is assumed to reside in the same MongoDB cluster as the encrypted
AsyncMongoClient/MongoClient. Use this option to route data key queries to a
separate MongoDB cluster.
:param schema_map: Map of collection namespace ("db.coll") to
JSON Schema. By default, a collection's JSONSchema is periodically
polled with the listCollections command. But a JSONSchema may be
specified locally with the schemaMap option.
**Supplying a `schema_map` provides more security than relying on
JSON Schemas obtained from the server. It protects against a
malicious server advertising a false JSON Schema, which could trick
the client into sending unencrypted data that should be
encrypted.**
Schemas supplied in the schemaMap only apply to configuring
automatic encryption for client side encryption. Other validation
rules in the JSON schema will not be enforced by the driver and
will result in an error.
:param bypass_auto_encryption: If ``True``, automatic
encryption will be disabled but automatic decryption will still be
enabled. Defaults to ``False``.
:param mongocryptd_uri: The MongoDB URI used to connect
to the *local* mongocryptd process. Defaults to
``'mongodb://localhost:27020'``.
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
AsyncMongoClient/MongoClient will not attempt to spawn the mongocryptd process.
Defaults to ``False``.
:param mongocryptd_spawn_path: Used for spawning the
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
mongocryptd from the system path.
:param mongocryptd_spawn_args: A list of string arguments to
use when spawning the mongocryptd process. Defaults to
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
the ``idleShutdownTimeoutSecs`` option then
``'--idleShutdownTimeoutSecs=60'`` will be added.
:param kms_tls_options: A map of KMS provider names to TLS
options to use when creating secure connections to KMS providers.
Accepts the same TLS options as
:class:`pymongo.mongo_client.AsyncMongoClient` and :class:`pymongo.mongo_client.MongoClient`. For example, to
override the system default CA file::
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
Or to supply a client certificate::
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
unable to load the crypt_shared library.
:param bypass_query_analysis: If ``True``, disable automatic analysis
of outgoing commands. Set `bypass_query_analysis` to use explicit
encryption on indexed fields without the MongoDB Enterprise Advanced
licensed crypt_shared library.
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
that described the encrypted fields for Queryable Encryption. For example::
{
"db.encryptedCollection": {
"escCollection": "enxcol_.encryptedCollection.esc",
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
"fields": [
{
"path": "firstName",
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
"bsonType": "string",
"queries": {"queryType": "equality"}
},
{
"path": "ssn",
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
"bsonType": "string"
}
]
}
}
.. versionchanged:: 4.2
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
and `bypass_query_analysis` parameters.
.. versionchanged:: 4.0
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
self._bypass_query_analysis = bypass_query_analysis
self._crypt_shared_lib_path = crypt_shared_lib_path
self._crypt_shared_lib_required = crypt_shared_lib_required
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._schema_map = schema_map
self._bypass_auto_encryption = bypass_auto_encryption
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
if mongocryptd_spawn_args is None:
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError("mongocryptd_spawn_args must be a list")
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
self._bypass_query_analysis = bypass_query_analysis
class RangeOpts:
"""Options to configure encrypted queries using the range algorithm."""
def __init__(
self,
sparsity: int,
trim_factor: int,
min: Optional[Any] = None,
max: Optional[Any] = None,
precision: Optional[int] = None,
) -> None:
"""Options to configure encrypted queries using the range algorithm.
:param sparsity: An integer.
:param trim_factor: An integer.
:param min: A BSON scalar value corresponding to the type being queried.
:param max: A BSON scalar value corresponding to the type being queried.
:param precision: An integer, may only be set for double or decimal128 types.
.. versionadded:: 4.4
"""
self.min = min
self.max = max
self.sparsity = sparsity
self.trim_factor = trim_factor
self.precision = precision
@property
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),
("trimFactor", self.trim_factor),
("precision", self.precision),
("min", self.min),
("max", self.max),
]:
if v is not None:
doc[k] = v
return doc

View File

@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Un
from bson.errors import InvalidDocument
if TYPE_CHECKING:
from pymongo.asynchronous.typings import _DocumentOut
from pymongo.typings import _DocumentOut
class PyMongoError(Exception):

View File

@ -1,4 +1,4 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,10 +12,212 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-import of synchronous EventLoggers API for compatibility."""
"""Example event logger classes.
.. versionadded:: 3.11
These loggers can be registered using :func:`register` or
:class:`~pymongo.mongo_client.MongoClient`.
``monitoring.register(CommandLogger())``
or
``MongoClient(event_listeners=[CommandLogger()])``
"""
from __future__ import annotations
from pymongo.synchronous.event_loggers import * # noqa: F403
from pymongo.synchronous.event_loggers import __doc__ as original_doc
import logging
__doc__ = original_doc
from pymongo import monitoring
class CommandLogger(monitoring.CommandListener):
"""A simple listener that logs command events.
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
:class:`~pymongo.monitoring.CommandSucceededEvent` and
:class:`~pymongo.monitoring.CommandFailedEvent` events and
logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.CommandStartedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} started on server "
f"{event.connection_id}"
)
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"succeeded in {event.duration_micros} "
"microseconds"
)
def failed(self, event: monitoring.CommandFailedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"failed in {event.duration_micros} "
"microseconds"
)
class ServerLogger(monitoring.ServerListener):
"""A simple listener that logs server discovery events.
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.ServerClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
previous_server_type = event.previous_description.server_type
new_server_type = event.new_description.server_type
if new_server_type != previous_server_type:
# server_type_name was added in PyMongo 3.4
logging.info(
f"Server {event.server_address} changed type from "
f"{event.previous_description.server_type_name} to "
f"{event.new_description.server_type_name}"
)
def closed(self, event: monitoring.ServerClosedEvent) -> None:
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
"""A simple listener that logs server heartbeat events.
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
logging.info(f"Heartbeat sent to server {event.connection_id}")
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
# The reply.document attribute was added in PyMongo 3.4.
logging.info(
f"Heartbeat to server {event.connection_id} "
"succeeded with reply "
f"{event.reply.document}"
)
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
logging.warning(
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
)
class TopologyLogger(monitoring.TopologyListener):
"""A simple listener that logs server topology events.
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.TopologyClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} opened")
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
logging.info(f"Topology description updated for topology id {event.topology_id}")
previous_topology_type = event.previous_description.topology_type
new_topology_type = event.new_description.topology_type
if new_topology_type != previous_topology_type:
# topology_type_name was added in PyMongo 3.4
logging.info(
f"Topology {event.topology_id} changed type from "
f"{event.previous_description.topology_type_name} to "
f"{event.new_description.topology_type_name}"
)
# The has_writable_server and has_readable_server methods
# were added in PyMongo 3.4.
if not event.new_description.has_writable_server():
logging.warning("No writable servers available.")
if not event.new_description.has_readable_server():
logging.warning("No readable servers available.")
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} closed")
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
"""A simple listener that logs server connection pool events.
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
:class:`~pymongo.monitoring.PoolClearedEvent`,
:class:`~pymongo.monitoring.PoolClosedEvent`,
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
logging.info(f"[pool {event.address}] pool created")
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
logging.info(f"[pool {event.address}] pool ready")
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
logging.info(f"[pool {event.address}] pool cleared")
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
logging.info(f"[pool {event.address}] pool closed")
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
)
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] "
f'connection closed, reason: "{event.reason}"'
)
def connection_check_out_started(
self, event: monitoring.ConnectionCheckOutStartedEvent
) -> None:
logging.info(f"[pool {event.address}] connection check out started")
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
)
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
)

View File

@ -21,12 +21,9 @@ import itertools
from typing import Any, Generic, Mapping, Optional
from bson.objectid import ObjectId
from pymongo.asynchronous import common
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.asynchronous.typings import ClusterTime, _DocumentType
from pymongo import common
from pymongo.server_type import SERVER_TYPE
_IS_SYNC = False
from pymongo.typings import ClusterTime, _DocumentType
def _get_server_type(doc: Mapping[str, Any]) -> int:
@ -57,6 +54,14 @@ def _get_server_type(doc: Mapping[str, Any]) -> int:
return SERVER_TYPE.Standalone
class HelloCompat:
CMD = "hello"
LEGACY_CMD = "ismaster"
PRIMARY = "isWritablePrimary"
LEGACY_PRIMARY = "ismaster"
LEGACY_ERROR = "not master"
class Hello(Generic[_DocumentType]):
"""Parse a hello response from the server.

View File

@ -1,72 +0,0 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used by the driver that don't really fit elsewhere."""
# From the SDAM spec, the "node is shutting down" codes.
from __future__ import annotations
_SHUTDOWN_CODES: frozenset = frozenset(
[
11600, # InterruptedAtShutdown
91, # ShutdownInProgress
]
)
# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = (
frozenset(
[
10058, # LegacyNotPrimary <=3.2 "not primary" error code
10107, # NotWritablePrimary
13435, # NotPrimaryNoSecondaryOk
11602, # InterruptedDueToReplStateChange
13436, # NotPrimaryOrSecondary
189, # PrimarySteppedDown
]
)
| _SHUTDOWN_CODES
)
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
[
7, # HostNotFound
6, # HostUnreachable
89, # NetworkTimeout
9001, # SocketException
262, # ExceededTimeLimit
134, # ReadConcernMajorityNotAvailableYet
]
)
# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391
# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18
# Note - to avoid bugs from forgetting which if these is all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = {
"authenticate",
"saslstart",
"saslcontinue",
"getnonce",
"createuser",
"updateuser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
}

327
pymongo/helpers_shared.py Normal file
View File

@ -0,0 +1,327 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bits and pieces used by the driver that don't really fit elsewhere."""
from __future__ import annotations
import sys
import traceback
from collections import abc
from typing import (
TYPE_CHECKING,
Any,
Container,
Iterable,
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
from pymongo import ASCENDING
from pymongo.errors import (
CursorNotFound,
DuplicateKeyError,
ExecutionTimeout,
NotPrimaryError,
OperationFailure,
WriteConcernError,
WriteError,
WTimeoutError,
_wtimeout_error,
)
from pymongo.hello import HelloCompat
if TYPE_CHECKING:
from pymongo.cursor_shared import _Hint
from pymongo.operations import _IndexList
from pymongo.typings import _DocumentOut
# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES: frozenset = frozenset(
[
11600, # InterruptedAtShutdown
91, # ShutdownInProgress
]
)
# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = (
frozenset(
[
10058, # LegacyNotPrimary <=3.2 "not primary" error code
10107, # NotWritablePrimary
13435, # NotPrimaryNoSecondaryOk
11602, # InterruptedDueToReplStateChange
13436, # NotPrimaryOrSecondary
189, # PrimarySteppedDown
]
)
| _SHUTDOWN_CODES
)
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
[
7, # HostNotFound
6, # HostUnreachable
89, # NetworkTimeout
9001, # SocketException
262, # ExceededTimeLimit
134, # ReadConcernMajorityNotAvailableYet
]
)
# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391
# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18
# Note - to avoid bugs from forgetting which if these is all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = {
"authenticate",
"saslstart",
"saslcontinue",
"getnonce",
"createuser",
"updateuser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
}
def _gen_index_name(keys: _IndexList) -> str:
"""Generate an index name from the set of fields it is over."""
return "_".join(["{}_{}".format(*item) for item in keys])
def _index_list(
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
"""Helper to generate a list of (key, direction) pairs.
Takes such a list, or a single key, or a single key and direction.
"""
if direction is not None:
if not isinstance(key_or_list, str):
raise TypeError("Expected a string and a direction")
return [(key_or_list, direction)]
else:
if isinstance(key_or_list, str):
return [(key_or_list, ASCENDING)]
elif isinstance(key_or_list, abc.ItemsView):
return list(key_or_list) # type: ignore[arg-type]
elif isinstance(key_or_list, abc.Mapping):
return list(key_or_list.items())
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
values: list[tuple[str, int]] = []
for item in key_or_list:
if isinstance(item, str):
item = (item, ASCENDING) # noqa: PLW2901
values.append(item)
return values
def _index_document(index_list: _IndexList) -> dict[str, Any]:
"""Helper to generate an index specifying document.
Takes a list of (key, direction) pairs.
"""
if not isinstance(index_list, (list, tuple, abc.Mapping)):
raise TypeError(
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
)
if not len(index_list):
raise ValueError("key_or_list must not be empty")
index: dict[str, Any] = {}
if isinstance(index_list, abc.Mapping):
for key in index_list:
value = index_list[key]
_validate_index_key_pair(key, value)
index[key] = value
else:
for item in index_list:
if isinstance(item, str):
item = (item, ASCENDING) # noqa: PLW2901
key, value = item
_validate_index_key_pair(key, value)
index[key] = value
return index
def _validate_index_key_pair(key: Any, value: Any) -> None:
if not isinstance(key, str):
raise TypeError("first item in each key pair must be an instance of str")
if not isinstance(value, (str, int, abc.Mapping)):
raise TypeError(
"second item in each key pair must be 1, -1, "
"'2d', or another valid MongoDB index specifier."
)
def _check_command_response(
response: _DocumentOut,
max_wire_version: Optional[int],
allowable_errors: Optional[Container[Union[int, str]]] = None,
parse_write_concern_error: bool = False,
) -> None:
"""Check the response to a command for errors."""
if "ok" not in response:
# Server didn't recognize our message as a command.
raise OperationFailure(
response.get("$err"), # type: ignore[arg-type]
response.get("code"),
response,
max_wire_version,
)
if parse_write_concern_error and "writeConcernError" in response:
_error = response["writeConcernError"]
_labels = response.get("errorLabels")
if _labels:
_error.update({"errorLabels": _labels})
_raise_write_concern_error(_error)
if response["ok"]:
return
details = response
# Mongos returns the error details in a 'raw' object
# for some errors.
if "raw" in response:
for shard in response["raw"].values():
# Grab the first non-empty raw error from a shard.
if shard.get("errmsg") and not shard.get("ok"):
details = shard
break
errmsg = details["errmsg"]
code = details.get("code")
# For allowable errors, only check for error messages when the code is not
# included.
if allowable_errors:
if code is not None:
if code in allowable_errors:
return
elif errmsg in allowable_errors:
return
# Server is "not primary" or "recovering"
if code is not None:
if code in _NOT_PRIMARY_CODES:
raise NotPrimaryError(errmsg, response)
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
raise NotPrimaryError(errmsg, response)
# Other errors
# findAndModify with upsert can raise duplicate key error
if code in (11000, 11001, 12582):
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
elif code == 50:
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
elif code == 43:
raise CursorNotFound(errmsg, code, response, max_wire_version)
raise OperationFailure(errmsg, code, response, max_wire_version)
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
# If the last batch had multiple errors only report
# the last error to emulate continue_on_error.
error = write_errors[-1]
if error.get("code") == 11000:
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
raise WriteError(error.get("errmsg"), error.get("code"), error)
def _raise_write_concern_error(error: Any) -> NoReturn:
if _wtimeout_error(error):
# Make sure we raise WTimeoutError
raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
raise WriteConcernError(error.get("errmsg"), error.get("code"), error)
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
"""Return the writeConcernError or None."""
wce = result.get("writeConcernError")
if wce:
# The server reports errorLabels at the top level but it's more
# convenient to attach it to the writeConcernError doc itself.
error_labels = result.get("errorLabels")
if error_labels:
# Copy to avoid changing the original document.
wce = wce.copy()
wce["errorLabels"] = error_labels
return wce
def _check_write_command_response(result: Mapping[str, Any]) -> None:
"""Backward compatibility helper for write command error handling."""
# Prefer write errors over write concern errors
write_errors = result.get("writeErrors")
if write_errors:
_raise_last_write_error(write_errors)
wce = _get_wce_doc(result)
if wce:
_raise_write_concern_error(wce)
def _fields_list_to_dict(
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
"""Takes a sequence of field names and returns a matching dictionary.
["a", "b"] becomes {"a": 1, "b": 1}
and
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
"""
if isinstance(fields, abc.Mapping):
return fields
if isinstance(fields, (abc.Sequence, abc.Set)):
if not all(isinstance(field, str) for field in fields):
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
return dict.fromkeys(fields, 1)
raise TypeError(f"{option_name} must be a mapping or list of key names")
def _handle_exception() -> None:
"""Print exceptions raised by subscribers to stderr."""
# Heavily influenced by logging.Handler.handleError.
# See note here:
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
if sys.stderr:
einfo = sys.exc_info()
try:
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
except OSError:
pass
finally:
del einfo

View File

@ -21,9 +21,7 @@ from typing import Any
from bson import UuidRepresentation, json_util
from bson.json_util import JSONOptions, _truncate_documents
from pymongo.synchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
_IS_SYNC = True
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
class _CommandStatusMessage(str, enum.Enum):

View File

@ -34,9 +34,8 @@ from pymongo.errors import ConfigurationError
from pymongo.server_type import SERVER_TYPE
if TYPE_CHECKING:
from pymongo.synchronous.server_selectors import Selection
from pymongo.server_selectors import Selection
_IS_SYNC = True
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
# 10 seconds to refresh secondaries' lastWriteDate values.

File diff suppressed because it is too large Load Diff

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.mongo_client import * # noqa: F403
from pymongo.synchronous.mongo_client import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["MongoClient"] # noqa: F405

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,10 +12,613 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-import of synchronous Operations API for compatibility."""
"""Operation class definitions."""
from __future__ import annotations
from pymongo.synchronous.operations import * # noqa: F403
from pymongo.synchronous.operations import __doc__ as original_doc
import enum
from typing import (
TYPE_CHECKING,
Any,
Generic,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
__doc__ = original_doc
from bson.raw_bson import RawBSONDocument
from pymongo import helpers_shared
from pymongo.collation import validate_collation_or_none
from pymongo.common import validate_is_mapping, validate_list
from pymongo.helpers_shared import _gen_index_name, _index_document, _index_list
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.write_concern import validate_boolean
if TYPE_CHECKING:
from pymongo.typings import _AgnosticBulk
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
_IndexList = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_IndexKeyHint = Union[str, _IndexList]
class _Op(str, enum.Enum):
ABORT = "abortTransaction"
AGGREGATE = "aggregate"
COMMIT = "commitTransaction"
COUNT = "count"
CREATE = "create"
CREATE_INDEXES = "createIndexes"
CREATE_SEARCH_INDEXES = "createSearchIndexes"
DELETE = "delete"
DISTINCT = "distinct"
DROP = "drop"
DROP_DATABASE = "dropDatabase"
DROP_INDEXES = "dropIndexes"
DROP_SEARCH_INDEXES = "dropSearchIndexes"
END_SESSIONS = "endSessions"
FIND_AND_MODIFY = "findAndModify"
FIND = "find"
INSERT = "insert"
LIST_COLLECTIONS = "listCollections"
LIST_INDEXES = "listIndexes"
LIST_SEARCH_INDEX = "listSearchIndexes"
LIST_DATABASES = "listDatabases"
UPDATE = "update"
UPDATE_INDEX = "updateIndex"
UPDATE_SEARCH_INDEX = "updateSearchIndex"
RENAME = "rename"
GETMORE = "getMore"
KILL_CURSORS = "killCursors"
TEST = "testOperation"
class InsertOne(Generic[_DocumentType]):
"""Represents an insert_one operation."""
__slots__ = ("_doc",)
def __init__(self, document: _DocumentType) -> None:
"""Create an InsertOne instance.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
:param document: The document to insert. If the document is missing an
_id field one will be added.
"""
self._doc = document
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
bulkobj.add_insert(self._doc) # type: ignore[arg-type]
def __repr__(self) -> str:
return f"InsertOne({self._doc!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return other._doc == self._doc
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteOne:
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteOne instance.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the document to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
1,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteMany:
"""Represents a delete_many operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteMany instance.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the documents to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
0,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class ReplaceOne(Generic[_DocumentType]):
"""Represents a replace_one operation."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
replacement: Union[_DocumentType, RawBSONDocument],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a ReplaceOne instance.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the document to replace.
:param replacement: The new document.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the ``collation`` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = replacement
self._upsert = upsert
self._collation = collation
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
bulkobj.add_replace(
self._filter,
self._doc,
self._upsert,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
other._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._hint,
)
class _UpdateOp:
"""Private base class for update operations."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
doc: Union[Mapping[str, Any], _Pipeline],
upsert: bool,
collation: Optional[_CollationIn],
array_filters: Optional[list[Mapping[str, Any]]],
hint: Optional[_IndexKeyHint],
):
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if array_filters is not None:
validate_list("array_filters", array_filters)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = doc
self._upsert = upsert
self._collation = collation
self._array_filters = array_filters
def __eq__(self, other: object) -> bool:
if isinstance(other, type(self)):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._array_filters,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
return NotImplemented
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
class UpdateOne(_UpdateOp):
"""Represents an update_one operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Represents an update_one operation.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the document to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
False,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class UpdateMany(_UpdateOp):
"""Represents an update_many operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create an UpdateMany instance.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the documents to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
True,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class IndexModel:
"""Represents an index to create."""
__slots__ = ("__document",)
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
"""Create an Index instance.
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_indexes` and :meth:`~pymongo.collection.Collection.create_indexes`.
Takes either a single key or a list containing (key, direction) pairs
or keys. If no direction is given, :data:`~pymongo.ASCENDING` will
be assumed.
The key(s) must be an instance of :class:`str`, and the direction(s) must
be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
:data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
Valid options include, but are not limited to:
- `name`: custom name to use for this index - if none is
given, a name will be generated.
- `unique`: if ``True``, creates a uniqueness constraint on the index.
- `background`: if ``True``, this index should be created in the
background.
- `sparse`: if ``True``, omit from the index any documents that lack
the indexed field.
- `bucketSize`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity
to a given longitude and latitude.
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
index.
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
index.
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
collection. MongoDB will automatically delete documents from
this collection after <int> seconds. The indexed field must
be a UTC datetime or the data will not expire.
- `partialFilterExpression`: A document that specifies a filter for
a partial index.
- `collation`: An instance of :class:`~pymongo.collation.Collation`
that specifies the collation to use.
- `wildcardProjection`: Allows users to include or exclude specific
field paths from a `wildcard index`_ using the { "$**" : 1} key
pattern. Requires MongoDB >= 4.2.
- `hidden`: if ``True``, this index will be hidden from the query
planner and will not be evaluated as part of query plan
selection. Requires MongoDB >= 4.4.
See the MongoDB documentation for a full list of supported options by
server version.
:param keys: a single key or a list containing (key, direction) pairs
or keys specifying the index to create.
:param kwargs: any additional index creation
options (see the above list) should be passed as keyword
arguments.
.. versionchanged:: 3.11
Added the ``hidden`` option.
.. versionchanged:: 3.2
Added the ``partialFilterExpression`` option to support partial
indexes.
.. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
"""
keys = _index_list(keys)
if kwargs.get("name") is None:
kwargs["name"] = _gen_index_name(keys)
kwargs["key"] = _index_document(keys)
collation = validate_collation_or_none(kwargs.pop("collation", None))
self.__document = kwargs
if collation is not None:
self.__document["collation"] = collation
@property
def document(self) -> dict[str, Any]:
"""An index document suitable for passing to the createIndexes
command.
"""
return self.__document
class SearchIndexModel:
"""Represents a search index to create."""
__slots__ = ("__document",)
def __init__(
self,
definition: Mapping[str, Any],
name: Optional[str] = None,
type: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a Search Index instance.
For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`.
:param definition: The definition for this index.
:param name: The name for this index, if present.
:param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
:param kwargs: Keyword arguments supplying any additional options.
.. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
.. versionadded:: 4.5
.. versionchanged:: 4.7
Added the type and kwargs arguments.
"""
self.__document: dict[str, Any] = {}
if name is not None:
self.__document["name"] = name
self.__document["definition"] = definition
if type is not None:
self.__document["type"] = type
self.__document.update(kwargs)
@property
def document(self) -> Mapping[str, Any]:
"""The document for this index."""
return self.__document

View File

@ -19,3 +19,4 @@ from pymongo.synchronous.pool import * # noqa: F403
from pymongo.synchronous.pool import __doc__ as original_doc
__doc__ = original_doc
__all__ = ["PoolOptions"] # noqa: F405

507
pymongo/pool_options.py Normal file
View File

@ -0,0 +1,507 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""AsyncConnection pool options for AsyncMongoClient/MongoClient."""
from __future__ import annotations
import copy
import os
import platform
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, MutableMapping, Optional
import bson
from pymongo import __version__
from pymongo.common import (
MAX_CONNECTING,
MAX_IDLE_TIME_SEC,
MAX_POOL_SIZE,
MIN_POOL_SIZE,
WAIT_QUEUE_TIMEOUT,
)
if TYPE_CHECKING:
from pymongo.auth_shared import MongoCredential
from pymongo.compression_support import CompressionSettings
from pymongo.driver_info import DriverInfo
from pymongo.monitoring import _EventListeners
from pymongo.pyopenssl_context import SSLContext
from pymongo.server_api import ServerApi
_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}}
if sys.platform.startswith("linux"):
# platform.linux_distribution was deprecated in Python 3.5
# and removed in Python 3.8. Starting in Python 3.5 it
# raises DeprecationWarning
# DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5
_name = platform.system()
_METADATA["os"] = {
"type": _name,
"name": _name,
"architecture": platform.machine(),
# Kernel version (e.g. 4.4.0-17-generic).
"version": platform.release(),
}
elif sys.platform == "darwin":
_METADATA["os"] = {
"type": platform.system(),
"name": platform.system(),
"architecture": platform.machine(),
# (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin
# kernel version.
"version": platform.mac_ver()[0],
}
elif sys.platform == "win32":
_METADATA["os"] = {
"type": platform.system(),
# "Windows XP", "Windows 7", "Windows 10", etc.
"name": " ".join((platform.system(), platform.release())),
"architecture": platform.machine(),
# Windows patch level (e.g. 5.1.2600-SP3)
"version": "-".join(platform.win32_ver()[1:3]),
}
elif sys.platform.startswith("java"):
_name, _ver, _arch = platform.java_ver()[-1]
_METADATA["os"] = {
# Linux, Windows 7, Mac OS X, etc.
"type": _name,
"name": _name,
# x86, x86_64, AMD64, etc.
"architecture": _arch,
# Linux kernel version, OSX version, etc.
"version": _ver,
}
else:
# Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11)
_aliased = platform.system_alias(platform.system(), platform.release(), platform.version())
_METADATA["os"] = {
"type": platform.system(),
"name": " ".join([part for part in _aliased[:2] if part]),
"architecture": platform.machine(),
"version": _aliased[2],
}
if platform.python_implementation().startswith("PyPy"):
_METADATA["platform"] = " ".join(
(
platform.python_implementation(),
".".join(map(str, sys.pypy_version_info)), # type: ignore
"(Python %s)" % ".".join(map(str, sys.version_info)),
)
)
elif sys.platform.startswith("java"):
_METADATA["platform"] = " ".join(
(
platform.python_implementation(),
".".join(map(str, sys.version_info)),
"(%s)" % " ".join((platform.system(), platform.release())),
)
)
else:
_METADATA["platform"] = " ".join(
(platform.python_implementation(), ".".join(map(str, sys.version_info)))
)
DOCKER_ENV_PATH = "/.dockerenv"
ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST"
RUNTIME_NAME_DOCKER = "docker"
ORCHESTRATOR_NAME_K8S = "kubernetes"
def get_container_env_info() -> dict[str, str]:
"""Returns the runtime and orchestrator of a container.
If neither value is present, the metadata client.env.container field will be omitted."""
container = {}
if Path(DOCKER_ENV_PATH).exists():
container["runtime"] = RUNTIME_NAME_DOCKER
if os.getenv(ENV_VAR_K8S):
container["orchestrator"] = ORCHESTRATOR_NAME_K8S
return container
def _is_lambda() -> bool:
if os.getenv("AWS_LAMBDA_RUNTIME_API"):
return True
env = os.getenv("AWS_EXECUTION_ENV")
if env:
return env.startswith("AWS_Lambda_")
return False
def _is_azure_func() -> bool:
return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME"))
def _is_gcp_func() -> bool:
return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME"))
def _is_vercel() -> bool:
return bool(os.getenv("VERCEL"))
def _is_faas() -> bool:
return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel()
def _getenv_int(key: str) -> Optional[int]:
"""Like os.getenv but returns an int, or None if the value is missing/malformed."""
val = os.getenv(key)
if not val:
return None
try:
return int(val)
except ValueError:
return None
def _metadata_env() -> dict[str, Any]:
env: dict[str, Any] = {}
container = get_container_env_info()
if container:
env["container"] = container
# Skip if multiple (or no) envs are matched.
if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1:
return env
if _is_lambda():
env["name"] = "aws.lambda"
region = os.getenv("AWS_REGION")
if region:
env["region"] = region
memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
if memory_mb is not None:
env["memory_mb"] = memory_mb
elif _is_azure_func():
env["name"] = "azure.func"
elif _is_gcp_func():
env["name"] = "gcp.func"
region = os.getenv("FUNCTION_REGION")
if region:
env["region"] = region
memory_mb = _getenv_int("FUNCTION_MEMORY_MB")
if memory_mb is not None:
env["memory_mb"] = memory_mb
timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC")
if timeout_sec is not None:
env["timeout_sec"] = timeout_sec
elif _is_vercel():
env["name"] = "vercel"
region = os.getenv("VERCEL_REGION")
if region:
env["region"] = region
return env
_MAX_METADATA_SIZE = 512
# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations
def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None:
"""Perform metadata truncation."""
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
# 1. Omit fields from env except env.name.
env_name = metadata.get("env", {}).get("name")
if env_name:
metadata["env"] = {"name": env_name}
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
# 2. Omit fields from os except os.type.
os_type = metadata.get("os", {}).get("type")
if os_type:
metadata["os"] = {"type": os_type}
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
# 3. Omit the env document entirely.
metadata.pop("env", None)
encoded_size = len(bson.encode(metadata))
if encoded_size <= _MAX_METADATA_SIZE:
return
# 4. Truncate platform.
overflow = encoded_size - _MAX_METADATA_SIZE
plat = metadata.get("platform", "")
if plat:
plat = plat[:-overflow]
if plat:
metadata["platform"] = plat
else:
metadata.pop("platform", None)
encoded_size = len(bson.encode(metadata))
if encoded_size <= _MAX_METADATA_SIZE:
return
# 5. Truncate driver info.
overflow = encoded_size - _MAX_METADATA_SIZE
driver = metadata.get("driver", {})
if driver:
# Truncate driver version.
driver_version = driver.get("version")[:-overflow]
if len(driver_version) >= len(_METADATA["driver"]["version"]):
metadata["driver"]["version"] = driver_version
else:
metadata["driver"]["version"] = _METADATA["driver"]["version"]
encoded_size = len(bson.encode(metadata))
if encoded_size <= _MAX_METADATA_SIZE:
return
# Truncate driver name.
overflow = encoded_size - _MAX_METADATA_SIZE
driver_name = driver.get("name")[:-overflow]
if len(driver_name) >= len(_METADATA["driver"]["name"]):
metadata["driver"]["name"] = driver_name
else:
metadata["driver"]["name"] = _METADATA["driver"]["name"]
# If the first getaddrinfo call of this interpreter's life is on a thread,
# while the main thread holds the import lock, getaddrinfo deadlocks trying
# to import the IDNA codec. Import it here, where presumably we're on the
# main thread, to avoid the deadlock. See PYTHON-607.
"foo".encode("idna")
class PoolOptions:
"""Read only connection pool options for an AsyncMongoClient/MongoClient.
Should not be instantiated directly by application developers. Access
a client's pool options via
:attr:`~pymongo.client_options.ClientOptions.pool_options` instead::
pool_opts = client.options.pool_options
pool_opts.max_pool_size
pool_opts.min_pool_size
"""
__slots__ = (
"__max_pool_size",
"__min_pool_size",
"__max_idle_time_seconds",
"__connect_timeout",
"__socket_timeout",
"__wait_queue_timeout",
"__ssl_context",
"__tls_allow_invalid_hostnames",
"__event_listeners",
"__appname",
"__driver",
"__metadata",
"__compression_settings",
"__max_connecting",
"__pause_enabled",
"__server_api",
"__load_balanced",
"__credentials",
)
def __init__(
self,
max_pool_size: int = MAX_POOL_SIZE,
min_pool_size: int = MIN_POOL_SIZE,
max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC,
connect_timeout: Optional[float] = None,
socket_timeout: Optional[float] = None,
wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT,
ssl_context: Optional[SSLContext] = None,
tls_allow_invalid_hostnames: bool = False,
event_listeners: Optional[_EventListeners] = None,
appname: Optional[str] = None,
driver: Optional[DriverInfo] = None,
compression_settings: Optional[CompressionSettings] = None,
max_connecting: int = MAX_CONNECTING,
pause_enabled: bool = True,
server_api: Optional[ServerApi] = None,
load_balanced: Optional[bool] = None,
credentials: Optional[MongoCredential] = None,
):
self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size
self.__max_idle_time_seconds = max_idle_time_seconds
self.__connect_timeout = connect_timeout
self.__socket_timeout = socket_timeout
self.__wait_queue_timeout = wait_queue_timeout
self.__ssl_context = ssl_context
self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames
self.__event_listeners = event_listeners
self.__appname = appname
self.__driver = driver
self.__compression_settings = compression_settings
self.__max_connecting = max_connecting
self.__pause_enabled = pause_enabled
self.__server_api = server_api
self.__load_balanced = load_balanced
self.__credentials = credentials
self.__metadata = copy.deepcopy(_METADATA)
if appname:
self.__metadata["application"] = {"name": appname}
# Combine the "driver" AsyncMongoClient option with PyMongo's info, like:
# {
# 'driver': {
# 'name': 'PyMongo|MyDriver',
# 'version': '4.2.0|1.2.3',
# },
# 'platform': 'CPython 3.8.0|MyPlatform'
# }
if driver:
if driver.name:
self.__metadata["driver"]["name"] = "{}|{}".format(
_METADATA["driver"]["name"],
driver.name,
)
if driver.version:
self.__metadata["driver"]["version"] = "{}|{}".format(
_METADATA["driver"]["version"],
driver.version,
)
if driver.platform:
self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform)
env = _metadata_env()
if env:
self.__metadata["env"] = env
_truncate_metadata(self.__metadata)
@property
def _credentials(self) -> Optional[MongoCredential]:
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
return self.__credentials
@property
def non_default_options(self) -> dict[str, Any]:
"""The non-default options this pool was created with.
Added for CMAP's :class:`PoolCreatedEvent`.
"""
opts = {}
if self.__max_pool_size != MAX_POOL_SIZE:
opts["maxPoolSize"] = self.__max_pool_size
if self.__min_pool_size != MIN_POOL_SIZE:
opts["minPoolSize"] = self.__min_pool_size
if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC:
assert self.__max_idle_time_seconds is not None
opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000
if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT:
assert self.__wait_queue_timeout is not None
opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000
if self.__max_connecting != MAX_CONNECTING:
opts["maxConnecting"] = self.__max_connecting
return opts
@property
def max_pool_size(self) -> float:
"""The maximum allowable number of concurrent connections to each
connected server. Requests to a server will block if there are
`maxPoolSize` outstanding connections to the requested server.
Defaults to 100. Cannot be 0.
When a server's pool has reached `max_pool_size`, operations for that
server block waiting for a socket to be returned to the pool. If
``waitQueueTimeoutMS`` is set, a blocked operation will raise
:exc:`~pymongo.errors.ConnectionFailure` after a timeout.
By default ``waitQueueTimeoutMS`` is not set.
"""
return self.__max_pool_size
@property
def min_pool_size(self) -> int:
"""The minimum required number of concurrent connections that the pool
will maintain to each connected server. Default is 0.
"""
return self.__min_pool_size
@property
def max_connecting(self) -> int:
"""The maximum number of concurrent connection creation attempts per
pool. Defaults to 2.
"""
return self.__max_connecting
@property
def pause_enabled(self) -> bool:
return self.__pause_enabled
@property
def max_idle_time_seconds(self) -> Optional[int]:
"""The maximum number of seconds that a connection can remain
idle in the pool before being removed and replaced. Defaults to
`None` (no limit).
"""
return self.__max_idle_time_seconds
@property
def connect_timeout(self) -> Optional[float]:
"""How long a connection can take to be opened before timing out."""
return self.__connect_timeout
@property
def socket_timeout(self) -> Optional[float]:
"""How long a send or receive on a socket can take before timing out."""
return self.__socket_timeout
@property
def wait_queue_timeout(self) -> Optional[int]:
"""How long a thread will wait for a socket from the pool if the pool
has no free sockets.
"""
return self.__wait_queue_timeout
@property
def _ssl_context(self) -> Optional[SSLContext]:
"""An SSLContext instance or None."""
return self.__ssl_context
@property
def tls_allow_invalid_hostnames(self) -> bool:
"""If True skip ssl.match_hostname."""
return self.__tls_allow_invalid_hostnames
@property
def _event_listeners(self) -> Optional[_EventListeners]:
"""An instance of pymongo.monitoring._EventListeners."""
return self.__event_listeners
@property
def appname(self) -> Optional[str]:
"""The application name, for sending with hello in server handshake."""
return self.__appname
@property
def driver(self) -> Optional[DriverInfo]:
"""Driver name and version, for sending with hello in handshake."""
return self.__driver
@property
def _compression_settings(self) -> Optional[CompressionSettings]:
return self.__compression_settings
@property
def metadata(self) -> dict[str, Any]:
"""A dict of metadata about the application, driver, os, and platform."""
return self.__metadata.copy()
@property
def server_api(self) -> Optional[ServerApi]:
"""A pymongo.server_api.ServerApi or None."""
return self.__server_api
@property
def load_balanced(self) -> Optional[bool]:
"""True if this Pool is configured in load balanced mode."""
return self.__load_balanced

View File

@ -1,6 +1,6 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2012-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License",
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@ -12,10 +12,612 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-import of synchronous ReadPreferences API for compatibility."""
"""Utilities for choosing which member of a replica set to read from."""
from __future__ import annotations
from pymongo.synchronous.read_preferences import * # noqa: F403
from pymongo.synchronous.read_preferences import __doc__ as original_doc
from collections import abc
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
__doc__ = original_doc
from pymongo import max_staleness_selectors
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import (
member_with_tags_server_selector,
secondary_with_tags_server_selector,
)
if TYPE_CHECKING:
from pymongo.server_selectors import Selection
from pymongo.topology_description import TopologyDescription
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
_SECONDARY_PREFERRED = 3
_NEAREST = 4
_MONGOS_MODES = (
"primary",
"primaryPreferred",
"secondary",
"secondaryPreferred",
"nearest",
)
_Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
"""Validate tag sets for a MongoClient."""
if tag_sets is None:
return tag_sets
if not isinstance(tag_sets, (list, tuple)):
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
if len(tag_sets) == 0:
raise ValueError(
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
)
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
f"Tag set {tags!r} invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping"
)
return list(tag_sets)
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
# Some duplication with common.py to avoid import cycle.
def _validate_max_staleness(max_staleness: Any) -> int:
"""Validate max_staleness."""
if max_staleness == -1:
return -1
if not isinstance(max_staleness, int):
raise TypeError(_invalid_max_staleness_msg(max_staleness))
if max_staleness <= 0:
raise ValueError(_invalid_max_staleness_msg(max_staleness))
return max_staleness
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
"""Validate hedge."""
if hedge is None:
return None
if not isinstance(hedge, dict):
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
return hedge
class _ServerMode:
"""Base class for all read preferences."""
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
def __init__(
self,
mode: int,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
self.__mongos_mode = _MONGOS_MODES[mode]
self.__mode = mode
self.__tag_sets = _validate_tag_sets(tag_sets)
self.__max_staleness = _validate_max_staleness(max_staleness)
self.__hedge = _validate_hedge(hedge)
@property
def name(self) -> str:
"""The name of this read preference."""
return self.__class__.__name__
@property
def mongos_mode(self) -> str:
"""The mongos mode of this read preference."""
return self.__mongos_mode
@property
def document(self) -> dict[str, Any]:
"""Read preference as a document."""
doc: dict[str, Any] = {"mode": self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc["tags"] = self.__tag_sets
if self.__max_staleness != -1:
doc["maxStalenessSeconds"] = self.__max_staleness
if self.__hedge not in (None, {}):
doc["hedge"] = self.__hedge
return doc
@property
def mode(self) -> int:
"""The mode of this read preference instance."""
return self.__mode
@property
def tag_sets(self) -> _TagSets:
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
read only from members whose ``dc`` tag has the value ``"ny"``.
To specify a priority-order for tag sets, provide a list of
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
set, ``{}``, means "read from any member that matches the mode,
ignoring tags." MongoClient tries each set of tags in turn
until it finds a set of tags with at least one matching member.
For example, to only send a query to an analytic node::
Nearest(tag_sets=[{"node":"analytics"}])
Or using :class:`SecondaryPreferred`::
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
.. seealso:: `Data-Center Awareness
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
"""
return list(self.__tag_sets) if self.__tag_sets else [{}]
@property
def max_staleness(self) -> int:
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum.
"""
return self.__max_staleness
@property
def hedge(self) -> Optional[_Hedge]:
"""The read preference ``hedge`` parameter.
A dictionary that configures how the server will perform hedged reads.
It consists of the following keys:
- ``enabled``: Enables or disables hedged reads in sharded clusters.
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
``nearest`` read preference. To explicitly enable hedged reads, set
the ``enabled`` key to ``true``::
>>> Nearest(hedge={'enabled': True})
To explicitly disable hedged reads, set the ``enabled`` key to
``False``::
>>> Nearest(hedge={'enabled': False})
.. versionadded:: 3.11
"""
return self.__hedge
@property
def min_wire_version(self) -> int:
"""The wire protocol version the server must support.
Some read preferences impose version requirements on all servers (e.g.
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
All servers' maxWireVersion must be at least this read preference's
`min_wire_version`, or the driver raises
:exc:`~pymongo.errors.ConfigurationError`.
"""
return 0 if self.__max_staleness == -1 else 5
def __repr__(self) -> str:
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
self.name,
self.__tag_sets,
self.__max_staleness,
self.__hedge,
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return (
self.mode == other.mode
and self.tag_sets == other.tag_sets
and self.max_staleness == other.max_staleness
and self.hedge == other.hedge
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __getstate__(self) -> dict[str, Any]:
"""Return value of object for pickling.
Needed explicitly because __slots__() defined.
"""
return {
"mode": self.__mode,
"tag_sets": self.__tag_sets,
"max_staleness": self.__max_staleness,
"hedge": self.__hedge,
}
def __setstate__(self, value: Mapping[str, Any]) -> None:
"""Restore from pickling."""
self.__mode = value["mode"]
self.__mongos_mode = _MONGOS_MODES[self.__mode]
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
self.__hedge = _validate_hedge(value["hedge"])
def __call__(self, selection: Selection) -> Selection:
return selection
class Primary(_ServerMode):
"""Primary read preference.
* When directly connected to one mongod queries are allowed if the server
is standalone or a replica set primary.
* When connected to a mongos queries are sent to the primary of a shard.
* When connected to a replica set queries are sent to the primary of
the replica set.
"""
__slots__ = ()
def __init__(self) -> None:
super().__init__(_PRIMARY)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return selection.primary_selection
def __repr__(self) -> str:
return "Primary()"
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return other.mode == _PRIMARY
return NotImplemented
class PrimaryPreferred(_ServerMode):
"""PrimaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are sent to the primary of a shard if
available, otherwise a shard secondary.
* When connected to a replica set queries are sent to the primary if
available, otherwise a secondary.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to an available secondary until the
primary of the replica set is discovered.
:param tag_sets: The :attr:`~tag_sets` to use if the primary is not
available.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` to use if the primary is not available.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
if selection.primary:
return selection.primary_selection
else:
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class Secondary(_ServerMode):
"""Secondary read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries. An error is raised if no secondaries are available.
* When connected to a replica set queries are distributed among
secondaries. An error is raised if no secondaries are available.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class SecondaryPreferred(_ServerMode):
"""SecondaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries, or the shard primary if no secondary is available.
* When connected to a replica set queries are distributed among
secondaries, or the primary if no secondary is available.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to the primary of the replica set until
an available secondary is discovered.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
secondaries = secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
if secondaries:
return secondaries
else:
return selection.primary_selection
class Nearest(_ServerMode):
"""Nearest read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among all members of
a shard.
* When connected to a replica set queries are distributed among all
members.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return member_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class _AggWritePref:
"""Agg $out/$merge write preference.
* If there are readable servers and there is any pre-5.0 server, use
primary read preference.
* Otherwise use `pref` read preference.
:param pref: The read preference to use on MongoDB 5.0+.
"""
__slots__ = ("pref", "effective_pref")
def __init__(self, pref: _ServerMode):
self.pref = pref
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
def selection_hook(self, topology_description: TopologyDescription) -> None:
common_wv = topology_description.common_wire_version
if (
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
and common_wv
and common_wv < 13
):
self.effective_pref = ReadPreference.PRIMARY
else:
self.effective_pref = self.pref
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return self.effective_pref(selection)
def __repr__(self) -> str:
return f"_AggWritePref(pref={self.pref!r})"
# Proxy other calls to the effective_pref so that _AggWritePref can be
# used in place of an actual read preference.
def __getattr__(self, name: str) -> Any:
return getattr(self.effective_pref, name)
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
def make_read_preference(
mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
) -> _ServerMode:
if mode == _PRIMARY:
if tag_sets not in (None, [{}]):
raise ConfigurationError("Read preference primary cannot be combined with tags")
if max_staleness != -1:
raise ConfigurationError(
"Read preference primary cannot be combined with maxStalenessSeconds"
)
return Primary()
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
_MODES = (
"PRIMARY",
"PRIMARY_PREFERRED",
"SECONDARY",
"SECONDARY_PREFERRED",
"NEAREST",
)
class ReadPreference:
"""An enum that defines some commonly used read preference modes.
Apps can also create a custom read preference, for example::
Nearest(tag_sets=[{"node":"analytics"}])
See :doc:`/examples/high_availability` for code examples.
A read preference is used in three cases:
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
set primary.
- All other modes allow queries to standalone servers, to a replica set
primary, or to replica set secondaries.
:class:`~pymongo.mongo_client.MongoClient` initialized with the
``replicaSet`` option:
- ``PRIMARY``: Read from the primary. This is the default, and provides the
strongest consistency. If no primary is available, raise
:class:`~pymongo.errors.AutoReconnect`.
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
none, read from a secondary.
- ``SECONDARY``: Read from a secondary. If no secondary is available,
raise :class:`~pymongo.errors.AutoReconnect`.
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
from the primary.
- ``NEAREST``: Read from any member.
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
sharded cluster of replica sets:
- ``PRIMARY``: Read from the primary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
This is the default.
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
none, read from a secondary of the shard.
- ``SECONDARY``: Read from a secondary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
otherwise from the shard primary.
- ``NEAREST``: Read from any shard member.
"""
PRIMARY = Primary()
PRIMARY_PREFERRED = PrimaryPreferred()
SECONDARY = Secondary()
SECONDARY_PREFERRED = SecondaryPreferred()
NEAREST = Nearest()
def read_pref_mode_from_name(name: str) -> int:
"""Get the read preference mode from mongos/uri name."""
return _MONGOS_MODES.index(name)
class MovingAverage:
"""Tracks an exponentially-weighted moving average."""
average: Optional[float]
def __init__(self) -> None:
self.average = None
def add_sample(self, sample: float) -> None:
if sample < 0:
# Likely system time change while waiting for hello response
# and not using time.monotonic. Ignore it, the next one will
# probably be valid.
return
if self.average is None:
self.average = sample
else:
# The Server Selection Spec requires an exponentially weighted
# average with alpha = 0.2.
self.average = 0.8 * self.average + 0.2 * sample
def get(self) -> Optional[float]:
"""Get the calculated average, or None if no samples yet."""
return self.average
def reset(self) -> None:
self.average = None

View File

@ -20,11 +20,8 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.synchronous.message import _OpMsg, _OpReply
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.typings import _Address, _DocumentOut
_IS_SYNC = True
from pymongo.message import _OpMsg, _OpReply
from pymongo.typings import _Address, _AgnosticConnection, _DocumentOut
class Response:
@ -92,7 +89,7 @@ class PinnedResponse(Response):
self,
data: Union[_OpMsg, _OpReply],
address: _Address,
conn: Connection,
conn: _AgnosticConnection,
request_id: int,
duration: Optional[timedelta],
from_command: bool,
@ -103,7 +100,7 @@ class PinnedResponse(Response):
:param data: A network response message.
:param address: (host, port) of the source server.
:param conn: The Connection used for the initial query.
:param conn: The AsyncConnection/Connection used for the initial query.
:param request_id: The request id of this operation.
:param duration: The duration of the operation.
:param from_command: If the response is the result of a db command.
@ -116,8 +113,8 @@ class PinnedResponse(Response):
self._more_to_come = more_to_come
@property
def conn(self) -> Connection:
"""The Connection used for the initial query.
def conn(self) -> _AgnosticConnection:
"""The AsyncConnection/Connection used for the initial query.
The server will send batches on this socket, without waiting for
getMores from the client, until the result set is exhausted or there

View File

@ -1,4 +1,4 @@
# Copyright 2024-present MongoDB, Inc.
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,10 +12,288 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-import of synchronous ServerDescription API for compatibility."""
"""Represent one server the driver is connected to."""
from __future__ import annotations
from pymongo.synchronous.server_description import * # noqa: F403
from pymongo.synchronous.server_description import __doc__ as original_doc
import time
import warnings
from typing import Any, Mapping, Optional
__doc__ = original_doc
from bson import EPOCH_NAIVE
from bson.objectid import ObjectId
from pymongo.hello import Hello
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import ClusterTime, _Address
class ServerDescription:
"""Immutable representation of one server.
:param address: A (host, port) pair
:param hello: Optional Hello instance
:param round_trip_time: Optional float
:param error: Optional, the last error attempting to connect to the server
:param round_trip_time: Optional float, the min latency from the most recent samples
"""
__slots__ = (
"_address",
"_server_type",
"_all_hosts",
"_tags",
"_replica_set_name",
"_primary",
"_max_bson_size",
"_max_message_size",
"_max_write_batch_size",
"_min_wire_version",
"_max_wire_version",
"_round_trip_time",
"_min_round_trip_time",
"_me",
"_is_writable",
"_is_readable",
"_ls_timeout_minutes",
"_error",
"_set_version",
"_election_id",
"_cluster_time",
"_last_write_date",
"_last_update_time",
"_topology_version",
)
def __init__(
self,
address: _Address,
hello: Optional[Hello] = None,
round_trip_time: Optional[float] = None,
error: Optional[Exception] = None,
min_round_trip_time: float = 0.0,
) -> None:
self._address = address
if not hello:
hello = Hello({})
self._server_type = hello.server_type
self._all_hosts = hello.all_hosts
self._tags = hello.tags
self._replica_set_name = hello.replica_set_name
self._primary = hello.primary
self._max_bson_size = hello.max_bson_size
self._max_message_size = hello.max_message_size
self._max_write_batch_size = hello.max_write_batch_size
self._min_wire_version = hello.min_wire_version
self._max_wire_version = hello.max_wire_version
self._set_version = hello.set_version
self._election_id = hello.election_id
self._cluster_time = hello.cluster_time
self._is_writable = hello.is_writable
self._is_readable = hello.is_readable
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
self._round_trip_time = round_trip_time
self._min_round_trip_time = min_round_trip_time
self._me = hello.me
self._last_update_time = time.monotonic()
self._error = error
self._topology_version = hello.topology_version
if error:
details = getattr(error, "details", None)
if isinstance(details, dict):
self._topology_version = details.get("topologyVersion")
self._last_write_date: Optional[float]
if hello.last_write_date:
# Convert from datetime to seconds.
delta = hello.last_write_date - EPOCH_NAIVE
self._last_write_date = delta.total_seconds()
else:
self._last_write_date = None
@property
def address(self) -> _Address:
"""The address (host, port) of this server."""
return self._address
@property
def server_type(self) -> int:
"""The type of this server."""
return self._server_type
@property
def server_type_name(self) -> str:
"""The server type as a human readable string.
.. versionadded:: 3.4
"""
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@property
def tags(self) -> Mapping[str, Any]:
return self._tags
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._replica_set_name
@property
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
return self._primary
@property
def max_bson_size(self) -> int:
return self._max_bson_size
@property
def max_message_size(self) -> int:
return self._max_message_size
@property
def max_write_batch_size(self) -> int:
return self._max_write_batch_size
@property
def min_wire_version(self) -> int:
return self._min_wire_version
@property
def max_wire_version(self) -> int:
return self._max_wire_version
@property
def set_version(self) -> Optional[int]:
return self._set_version
@property
def election_id(self) -> Optional[ObjectId]:
return self._election_id
@property
def cluster_time(self) -> Optional[ClusterTime]:
return self._cluster_time
@property
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
warnings.warn(
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
DeprecationWarning,
stacklevel=2,
)
return self._set_version, self._election_id
@property
def me(self) -> Optional[tuple[str, int]]:
return self._me
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._ls_timeout_minutes
@property
def last_write_date(self) -> Optional[float]:
return self._last_write_date
@property
def last_update_time(self) -> float:
return self._last_update_time
@property
def round_trip_time(self) -> Optional[float]:
"""The current average latency or None."""
# This override is for unittesting only!
if self._address in self._host_to_round_trip_time:
return self._host_to_round_trip_time[self._address]
return self._round_trip_time
@property
def min_round_trip_time(self) -> float:
"""The min latency from the most recent samples."""
return self._min_round_trip_time
@property
def error(self) -> Optional[Exception]:
"""The last error attempting to connect to the server, or None."""
return self._error
@property
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self) -> bool:
return self._is_readable
@property
def mongos(self) -> bool:
return self._server_type == SERVER_TYPE.Mongos
@property
def is_server_type_known(self) -> bool:
return self.server_type != SERVER_TYPE.Unknown
@property
def retryable_writes_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return (
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
) or self._server_type == SERVER_TYPE.LoadBalancer
@property
def retryable_reads_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return self._max_wire_version >= 6
@property
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._topology_version
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
unknown = ServerDescription(self.address, error=error)
unknown._topology_version = self.topology_version
return unknown
def __eq__(self, other: Any) -> bool:
if isinstance(other, ServerDescription):
return (
(self._address == other.address)
and (self._server_type == other.server_type)
and (self._min_wire_version == other.min_wire_version)
and (self._max_wire_version == other.max_wire_version)
and (self._me == other.me)
and (self._all_hosts == other.all_hosts)
and (self._tags == other.tags)
and (self._replica_set_name == other.replica_set_name)
and (self._set_version == other.set_version)
and (self._election_id == other.election_id)
and (self._primary == other.primary)
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
and (self._error == other.error)
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
errmsg = ""
if self.error:
errmsg = f", error={self.error!r}"
return "<{} {} server_type: {}, rtt: {}{}>".format(
self.__class__.__name__,
self.address,
self.server_type_name,
self.round_trip_time,
errmsg,
)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time: dict = {}

View File

@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cas
from pymongo.server_type import SERVER_TYPE
if TYPE_CHECKING:
from pymongo.synchronous.server_description import ServerDescription
from pymongo.synchronous.topology_description import TopologyDescription
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TopologyDescription
_IS_SYNC = True
T = TypeVar("T")
TagSet = Mapping[str, Any]

View File

@ -19,14 +19,12 @@ import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
from pymongo.synchronous.common import CONNECT_TIMEOUT
if TYPE_CHECKING:
from dns import resolver
_IS_SYNC = True
def _have_dnspython() -> bool:
try:

View File

@ -18,20 +18,20 @@ from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.synchronous import common
from pymongo.synchronous.collation import validate_collation_or_none
from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref
from pymongo.read_preferences import ReadPreference, _AggWritePref
if TYPE_CHECKING:
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.read_preferences import _ServerMode
from pymongo.synchronous.server import Server
from pymongo.synchronous.typings import _DocumentType, _Pipeline
from pymongo.typings import _DocumentType, _Pipeline
_IS_SYNC = True

View File

@ -18,16 +18,12 @@ from __future__ import annotations
import functools
import hashlib
import hmac
import os
import socket
import typing
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
MutableMapping,
Optional,
@ -36,20 +32,22 @@ from typing import (
from urllib.parse import quote
from bson.binary import Binary
from pymongo.auth_shared import (
MongoCredential,
_authenticate_scram_start,
_parse_scram_response,
_xor,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
from pymongo.synchronous.auth_aws import _authenticate_aws
from pymongo.synchronous.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
if TYPE_CHECKING:
from pymongo.synchronous.hello import Hello
from pymongo.hello import Hello
from pymongo.synchronous.pool import Connection
HAVE_KERBEROS = True
@ -68,210 +66,6 @@ except ImportError:
_IS_SYNC = True
MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DEFAULT",
]
)
"""The authentication mechanisms supported by PyMongo."""
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
def __init__(self) -> None:
self.data = None
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self) -> int:
return self._hash_val
MongoCredential = namedtuple(
"MongoCredential",
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
)
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple(
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
)
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
properties = extra.get("authmechanismproperties", {})
service_name = properties.get("SERVICE_NAME", "mongodb")
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
service_realm = properties.get("SERVICE_REALM")
props = GSSAPIProperties(
service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm,
)
# Source is always $external.
return MongoCredential(mech, "$external", user, passwd, props, None)
elif mech == "MONGODB-X509":
if passwd is not None:
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for MONGODB-X509")
# Source is always $external, user can be None.
return MongoCredential(mech, "$external", user, None, None, None)
elif mech == "MONGODB-AWS":
if user is not None and passwd is None:
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
if source is not None and source != "$external":
raise ConfigurationError(
"authentication source must be $external or None for MONGODB-AWS"
)
properties = extra.get("authmechanismproperties", {})
aws_session_token = properties.get("AWS_SESSION_TOKEN")
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
callback = properties.get("OIDC_CALLBACK")
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
environ = properties.get("ENVIRONMENT")
token_resource = properties.get("TOKEN_RESOURCE", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
if passwd is not None:
msg = "password is not supported by MONGODB-OIDC"
raise ConfigurationError(msg)
if callback or human_callback:
if environ is not None:
raise ConfigurationError(msg)
if callback and human_callback:
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
raise ConfigurationError(msg)
elif environ is not None:
if environ == "test":
if user is not None:
msg = "test environment for MONGODB-OIDC does not support username"
raise ConfigurationError(msg)
callback = _OIDCTestCallback()
elif environ == "azure":
passwd = None
if not token_resource:
raise ConfigurationError(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:
raise ConfigurationError(msg)
oidc_props = _OIDCProperties(
callback=callback,
human_callback=human_callback,
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
"""Split a scram response into key, value pairs."""
return dict(
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
for item in response.split(b",")
)
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = {
"saslStart": 1,
"mechanism": mechanism,
"payload": Binary(b"n,," + first_bare),
"autoAuthorize": 1,
"options": {"skipEmptyExchange": True},
}
return nonce, first_bare, cmd
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
"""Authenticate using SCRAM."""

View File

@ -23,7 +23,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING:
from bson.typings import _ReadableBuffer
from pymongo.synchronous.auth import MongoCredential
from pymongo.auth_shared import MongoCredential
from pymongo.synchronous.pool import Connection
_IS_SYNC = True

View File

@ -15,79 +15,35 @@
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import abc
import os
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
from urllib.parse import quote
import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.auth_oidc_shared import (
CALLBACK_VERSION,
HUMAN_CALLBACK_TIMEOUT_SECONDS,
MACHINE_CALLBACK_TIMEOUT_SECONDS,
TIME_BETWEEN_CALLS_SECONDS,
OIDCCallback,
OIDCCallbackContext,
OIDCCallbackResult,
OIDCIdPInfo,
_OIDCProperties,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.synchronous.auth import MongoCredential
from pymongo.auth_shared import MongoCredential
from pymongo.synchronous.pool import Connection
_IS_SYNC = True
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@dataclass
class OIDCCallbackContext:
timeout_seconds: float
username: str
version: int
refresh_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
@dataclass
class OIDCCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)
class OIDCCallback(abc.ABC):
"""A base class for defining OIDC callbacks."""
@abc.abstractmethod
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
"""Convert the given BSON value into our own type."""
@dataclass
class _OIDCProperties:
callback: Optional[OIDCCallback] = field(default=None)
human_callback: Optional[OIDCCallback] = field(default=None)
environment: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)
token_resource: Optional[str] = field(default=None)
username: str = ""
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
def _get_authenticator(
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
@ -117,48 +73,6 @@ def _get_authenticator(
return credentials.cache.data
class _OIDCTestCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("OIDC_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAWSCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAzureCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
return OIDCCallbackResult(
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])
@dataclass
class _OIDCAuthenticator:
username: str

View File

@ -19,6 +19,8 @@
from __future__ import annotations
import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
@ -26,7 +28,6 @@ from typing import (
Any,
Iterator,
Mapping,
NoReturn,
Optional,
Type,
Union,
@ -34,140 +35,50 @@ from typing import (
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure,
from pymongo import _csot, common
from pymongo.bulk_shared import (
_COMMANDS,
_DELETE_ALL,
_merge_command,
_raise_bulk_write_error,
_Run,
)
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
from pymongo.synchronous import common
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.common import (
from pymongo.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo.synchronous.helpers import _get_wce_doc
from pymongo.synchronous.message import (
from pymongo.errors import (
ConfigurationError,
InvalidOperation,
NotPrimaryError,
OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_convert_exception,
_convert_write_result,
_EncryptedBulkWriteContext,
_randint,
)
from pymongo.synchronous.read_preferences import ReadPreference
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.typings import _DocumentOut, _DocumentType, _Pipeline
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
_IS_SYNC = True
_DELETE_ALL: int = 0
_DELETE_ONE: int = 1
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
_BAD_VALUE: int = 2
_UNKNOWN_ERROR: int = 8
_WRITE_CONCERN_ERROR: int = 64
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
class _Run:
"""Represents a batch of write operations."""
def __init__(self, op_type: int) -> None:
"""Initialize a new Run object."""
self.op_type: int = op_type
self.index_map: list[int] = []
self.ops: list[Any] = []
self.idx_offset: int = 0
def index(self, idx: int) -> int:
"""Get the original index of an operation in this run.
:param idx: The Run index that maps to the original index.
"""
return self.index_map[idx]
def add(self, original_index: int, operation: Any) -> None:
"""Add an operation to this Run instance.
:param original_index: The original index of this operation
within a larger bulk operation.
:param operation: The operation document.
"""
self.index_map.append(original_index)
self.ops.append(operation)
def _merge_command(
run: _Run,
full_result: MutableMapping[str, Any],
offset: int,
result: Mapping[str, Any],
) -> None:
"""Merge a write command result into the full bulk result."""
affected = result.get("n", 0)
if run.op_type == _INSERT:
full_result["nInserted"] += affected
elif run.op_type == _DELETE:
full_result["nRemoved"] += affected
elif run.op_type == _UPDATE:
upserted = result.get("upserted")
if upserted:
n_upserted = len(upserted)
for doc in upserted:
doc["index"] = run.index(doc["index"] + offset)
full_result["upserted"].extend(upserted)
full_result["nUpserted"] += n_upserted
full_result["nMatched"] += affected - n_upserted
else:
full_result["nMatched"] += affected
full_result["nModified"] += result["nModified"]
write_errors = result.get("writeErrors")
if write_errors:
for doc in write_errors:
# Leave the server response intact for APM.
replacement = doc.copy()
idx = doc["index"] + offset
replacement["index"] = run.index(idx)
# Add the failed operation to the error document.
replacement["op"] = run.ops[idx]
full_result["writeErrors"].append(replacement)
wce = _get_wce_doc(result)
if wce:
full_result["writeConcernErrors"].append(wce)
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
"""Raise a BulkWriteError from the full bulk api result."""
# retryWrites on MMAPv1 should raise an actionable error.
if full_result["writeErrors"]:
full_result["writeErrors"].sort(key=lambda error: error["index"])
err = full_result["writeErrors"][0]
code = err["code"]
msg = err["errmsg"]
if code == 20 and msg.startswith("Transaction numbers"):
errmsg = (
"This MongoDB deployment does not support "
"retryable writes. Please add retryWrites=false "
"to your connection string."
)
raise OperationFailure(errmsg, code, full_result)
raise BulkWriteError(full_result)
class _Bulk:
"""The private guts of the bulk write API."""
@ -204,13 +115,16 @@ class _Bulk:
# Extra state so that we know where to pick up on a retry attempt.
self.current_run = None
self.next_run = None
self.is_encrypted = False
@property
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
encrypter = self.collection.database.client._encrypter
if encrypter and not encrypter._bypass_auto_encryption:
self.is_encrypted = True
return _EncryptedBulkWriteContext
else:
self.is_encrypted = False
return _BulkWriteContext
def add_insert(self, document: _DocumentOut) -> None:
@ -315,6 +229,230 @@ class _Bulk:
if run.ops:
yield run
@_handle_reauth
def write_command(
self,
bwc: _BulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
docs: list[Mapping[str, Any]],
client: MongoClient,
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, docs)
try:
reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
bwc._fail(request_id, failure, duration)
raise
finally:
bwc.start_time = datetime.datetime.now()
return reply # type: ignore[return-value]
def unack_write(
self,
bwc: _BulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
max_doc_size: int,
docs: list[Mapping[str, Any]],
client: MongoClient,
) -> Optional[Mapping[str, Any]]:
"""A proxy for Connection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, docs)
try:
result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
raise
finally:
bwc.start_time = datetime.datetime.now()
return result # type: ignore[return-value]
def _execute_batch_unack(
self,
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: MongoClient,
) -> list[Mapping[str, Any]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
bwc.conn.command( # type: ignore[misc]
bwc.db_name,
batched_cmd, # type: ignore[arg-type]
write_concern=WriteConcern(w=0),
session=bwc.session, # type: ignore[arg-type]
client=client, # type: ignore[arg-type]
)
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
# Though this isn't strictly a "legacy" write, the helper
# handles publishing commands and sending our message
# without receiving a result. Send 0 for max_doc_size
# to disable size checking. Size checking is handled while
# the documents are encoded to BSON.
self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type]
return to_send
def _execute_batch(
self,
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: MongoClient,
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
result = bwc.conn.command( # type: ignore[misc]
bwc.db_name,
batched_cmd, # type: ignore[arg-type]
codec_options=bwc.codec,
session=bwc.session, # type: ignore[arg-type]
client=client, # type: ignore[arg-type]
)
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send # type: ignore[return-value]
def _execute_command(
self,
generator: Iterator[Any],
@ -387,7 +525,7 @@ class _Bulk:
# Run as many ops as possible in one command.
if write_concern.acknowledged:
result, to_send = bwc.execute(cmd, ops, client)
result, to_send = self._execute_batch(bwc, cmd, ops, client)
# Retryable writeConcernErrors halt the execution of this run.
wce = result.get("writeConcernError", {})
@ -407,7 +545,7 @@ class _Bulk:
if self.ordered and "writeErrors" in result:
break
else:
to_send = bwc.execute_unack(cmd, ops, client)
to_send = self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
@ -458,7 +596,7 @@ class _Bulk:
retryable_bulk,
session,
operation,
bulk=self,
bulk=self, # type: ignore[arg-type]
operation_id=op_id,
)
@ -499,7 +637,7 @@ class _Bulk:
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = bwc.execute_unack(cmd, ops, client)
to_send = self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)

View File

@ -21,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo import _csot, common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (
ConnectionFailure,
CursorNotFound,
@ -29,16 +30,14 @@ from pymongo.errors import (
OperationFailure,
PyMongoError,
)
from pymongo.synchronous import common
from pymongo.operations import _Op
from pymongo.synchronous.aggregation import (
_AggregationCommand,
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
)
from pymongo.synchronous.collation import validate_collation_or_none
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.operations import _Op
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
_IS_SYNC = True

View File

@ -1,334 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
from bson.codec_options import _parse_codec_options
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.ssl_support import get_ssl_context
from pymongo.synchronous import common
from pymongo.synchronous.compression_support import CompressionSettings
from pymongo.synchronous.monitoring import _EventListener, _EventListeners
from pymongo.synchronous.pool import PoolOptions
from pymongo.synchronous.read_preferences import (
_ServerMode,
make_read_preference,
read_pref_mode_from_name,
)
from pymongo.synchronous.server_selectors import any_server_selector
from pymongo.write_concern import WriteConcern, validate_boolean
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.pyopenssl_context import SSLContext
from pymongo.synchronous.auth import MongoCredential
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
from pymongo.synchronous.topology_description import _ServerSelector
_IS_SYNC = True
def _parse_credentials(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
"""Parse authentication credentials."""
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
source = options.get("authsource")
if username or mechanism:
from pymongo.synchronous.auth import _build_credentials_tuple
return _build_credentials_tuple(mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
"""Parse read preference options."""
if "read_preference" in options:
return options["read_preference"]
name = options.get("readpreference", "primary")
mode = read_pref_mode_from_name(name)
tags = options.get("readpreferencetags")
max_staleness = options.get("maxstalenessseconds", -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
"""Parse write concern options."""
concern = options.get("w")
wtimeout = options.get("wtimeoutms")
j = options.get("journal")
fsync = options.get("fsync")
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
"""Parse read concern options."""
concern = options.get("readconcernlevel")
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
validate_boolean("tls", use_tls)
certfile = options.get("tlscertificatekeyfile")
passphrase = options.get("tlscertificatekeyfilepassword")
ca_certs = options.get("tlscafile")
crlfile = options.get("tlscrlfile")
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
enabled_tls_opts = []
for opt in (
"tlscertificatekeyfile",
"tlscertificatekeyfilepassword",
"tlscafile",
"tlscrlfile",
):
# Any non-null value of these options implies tls=True.
if opt in options and options[opt]:
enabled_tls_opts.append(opt)
for opt in (
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
):
# A value of False for these options implies tls=True.
if opt in options and not options[opt]:
enabled_tls_opts.append(opt)
if enabled_tls_opts:
if use_tls is None:
# Implicitly enable TLS when one of the tls* options is set.
use_tls = True
elif not use_tls:
# Error since tls is explicitly disabled but a tls option is set.
raise ConfigurationError(
"TLS has not been enabled but the "
"following tls parameters have been set: "
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
)
if use_tls:
ctx = get_ssl_context(
certfile,
passphrase,
ca_certs,
crlfile,
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
def _parse_pool_options(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(
max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout,
socket_timeout,
wait_queue_timeout,
ssl_context,
tls_allow_invalid_hostnames,
_EventListeners(event_listeners),
appname,
driver,
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced,
credentials=credentials,
)
class ClientOptions:
"""Read only configuration options for a MongoClient.
Should not be instantiated directly by application developers. Access
a client's options via :attr:`pymongo.mongo_client.MongoClient.options`
instead.
"""
def __init__(
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__direct_connection = options.get("directconnection")
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
)
self.__pool_options = _parse_pool_options(username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get("replicaset")
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get("connect")
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
self.__server_selector = options.get("server_selector", any_server_selector)
self.__auto_encryption_opts = options.get("auto_encryption_opts")
self.__load_balanced = options.get("loadbalanced")
self.__timeout = options.get("timeoutms")
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
@property
def _options(self) -> Mapping[str, Any]:
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self) -> Optional[bool]:
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self) -> CodecOptions:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def direct_connection(self) -> Optional[bool]:
"""Whether to connect to the deployment in 'Single' topology."""
return self.__direct_connection
@property
def local_threshold_ms(self) -> int:
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self) -> int:
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self) -> _ServerSelector:
return self.__server_selector
@property
def heartbeat_frequency(self) -> int:
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self) -> PoolOptions:
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self) -> _ServerMode:
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self) -> WriteConcern:
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self) -> ReadConcern:
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def timeout(self) -> Optional[float]:
"""The configured timeoutMS converted to seconds, or None.
.. versionadded:: 4.2
"""
return self.__timeout
@property
def retry_writes(self) -> bool:
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self) -> bool:
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self.__load_balanced
@property
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()
@property
def server_monitoring_mode(self) -> str:
"""The configured serverMonitoringMode option.
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode

View File

@ -164,12 +164,12 @@ from pymongo.errors import (
PyMongoError,
WTimeoutError,
)
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.operations import _Op
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous.cursor import _ConnectionManager
from pymongo.synchronous.operations import _Op
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
@ -177,7 +177,7 @@ if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.server import Server
from pymongo.synchronous.typings import ClusterTime, _Address
from pymongo.typings import ClusterTime, _Address
_IS_SYNC = True

View File

@ -1,226 +0,0 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with `collations`_.
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
from __future__ import annotations
from typing import Any, Mapping, Optional, Union
from pymongo.synchronous import common
from pymongo.write_concern import validate_boolean
_IS_SYNC = True
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = "non-ignorable"
"""Spaces and punctuation are treated as base characters."""
SHIFTED = "shifted"
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished regardless when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = "punct"
"""Both punctuation and spaces are ignored."""
SPACE = "space"
"""Spaces alone are ignored."""
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = "upper"
"""Sort uppercase characters first."""
LOWER = "lower"
"""Sort lowercase characters first."""
OFF = "off"
"""Default for locale or collation strength."""
class Collation:
"""Collation
:param locale: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
:param caseFirst: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
:param strength: Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
:param numericOrdering: If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
:param alternate: Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
:param maxVariable: When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
:param normalization: If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
:param backwards: If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
:param kwargs: Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded: 3.4
"""
__slots__ = ("__document",)
def __init__(
self,
locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
if strength is not None:
self.__document["strength"] = common.validate_integer("strength", strength)
if numericOrdering is not None:
self.__document["numericOrdering"] = validate_boolean(
"numericOrdering", numericOrdering
)
if alternate is not None:
self.__document["alternate"] = common.validate_string("alternate", alternate)
if maxVariable is not None:
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
if normalization is not None:
self.__document["normalization"] = validate_boolean("normalization", normalization)
if backwards is not None:
self.__document["backwards"] = validate_boolean("backwards", backwards)
self.__document.update(kwargs)
@property
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self) -> str:
document = self.document
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")

View File

@ -41,41 +41,18 @@ from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import ASCENDING, _csot
from pymongo import ASCENDING, _csot, common, helpers_shared, message
from pymongo.collation import validate_collation_or_none
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.errors import (
ConfigurationError,
InvalidName,
InvalidOperation,
OperationFailure,
)
from pymongo.read_concern import DEFAULT_READ_CONCERN
from pymongo.results import (
BulkWriteResult,
DeleteResult,
InsertManyResult,
InsertOneResult,
UpdateResult,
)
from pymongo.synchronous import common, helpers, message
from pymongo.synchronous.aggregation import (
_CollectionAggregationCommand,
_CollectionRawAggregationCommand,
)
from pymongo.synchronous.bulk import _Bulk
from pymongo.synchronous.change_stream import CollectionChangeStream
from pymongo.synchronous.collation import validate_collation_or_none
from pymongo.synchronous.command_cursor import (
CommandCursor,
RawBatchCommandCursor,
)
from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name
from pymongo.synchronous.cursor import (
Cursor,
RawBatchCursor,
)
from pymongo.synchronous.helpers import _check_write_command_response
from pymongo.synchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS
from pymongo.synchronous.operations import (
from pymongo.helpers_shared import _check_write_command_response
from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
from pymongo.operations import (
DeleteMany,
DeleteOne,
IndexModel,
@ -88,8 +65,30 @@ from pymongo.synchronous.operations import (
_IndexList,
_Op,
)
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.read_concern import DEFAULT_READ_CONCERN
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import (
BulkWriteResult,
DeleteResult,
InsertManyResult,
InsertOneResult,
UpdateResult,
)
from pymongo.synchronous.aggregation import (
_CollectionAggregationCommand,
_CollectionRawAggregationCommand,
)
from pymongo.synchronous.bulk import _Bulk
from pymongo.synchronous.change_stream import CollectionChangeStream
from pymongo.synchronous.command_cursor import (
CommandCursor,
RawBatchCommandCursor,
)
from pymongo.synchronous.cursor import (
Cursor,
RawBatchCursor,
)
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
_IS_SYNC = True
@ -125,10 +124,10 @@ class ReturnDocument:
if TYPE_CHECKING:
import bson
from pymongo.collation import Collation
from pymongo.read_concern import ReadConcern
from pymongo.synchronous.aggregation import _AggregationCommand
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collation import Collation
from pymongo.synchronous.database import Database
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.server import Server
@ -995,7 +994,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
if not isinstance(hint, str):
hint = helpers._index_document(hint)
hint = helpers_shared._index_document(hint)
update_doc["hint"] = hint
command = {"update": self.name, "ordered": ordered, "updates": [update_doc]}
if let is not None:
@ -1476,7 +1475,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
if not isinstance(hint, str):
hint = helpers._index_document(hint)
hint = helpers_shared._index_document(hint)
delete_doc["hint"] = hint
command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]}
@ -2092,7 +2091,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}})
cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}}
if "hint" in kwargs and not isinstance(kwargs["hint"], str):
kwargs["hint"] = helpers._index_document(kwargs["hint"])
kwargs["hint"] = helpers_shared._index_document(kwargs["hint"])
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd.update(kwargs)
@ -2425,7 +2424,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
) -> None:
name = index_or_name
if isinstance(index_or_name, list):
name = helpers._gen_index_name(index_or_name)
name = helpers_shared._gen_index_name(index_or_name)
if not isinstance(name, str):
raise TypeError("index_or_name must be an instance of str or list")
@ -3154,15 +3153,15 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd["let"] = let
cmd.update(kwargs)
if projection is not None:
cmd["fields"] = helpers._fields_list_to_dict(projection, "projection")
cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection")
if sort is not None:
cmd["sort"] = helpers._index_document(sort)
cmd["sort"] = helpers_shared._index_document(sort)
if upsert is not None:
validate_boolean("upsert", upsert)
cmd["upsert"] = upsert
if hint is not None:
if not isinstance(hint, str):
hint = helpers._index_document(hint)
hint = helpers_shared._index_document(hint)
write_concern = self._write_concern_for_cmd(cmd, session)

View File

@ -31,16 +31,16 @@ from typing import (
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.synchronous.cursor import _ConnectionManager
from pymongo.synchronous.message import (
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
_RawBatchGetMore,
)
from pymongo.synchronous.response import PinnedResponse
from pymongo.synchronous.typings import _Address, _DocumentOut, _DocumentType
from pymongo.response import PinnedResponse
from pymongo.synchronous.cursor import _ConnectionManager
from pymongo.typings import _Address, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from pymongo.synchronous.client_session import ClientSession
@ -272,7 +272,7 @@ class CommandCursor(Generic[_DocumentType]):
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]

View File

@ -36,17 +36,16 @@ from typing import (
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
from bson.code import Code
from bson.son import SON
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _create_lock
from pymongo.synchronous import helpers
from pymongo.synchronous.collation import validate_collation_or_none
from pymongo.synchronous.common import (
from pymongo import helpers_shared
from pymongo.collation import validate_collation_or_none
from pymongo.common import (
validate_is_document_type,
validate_is_mapping,
)
from pymongo.synchronous.helpers import next
from pymongo.synchronous.message import (
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _create_lock
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
@ -55,18 +54,19 @@ from pymongo.synchronous.message import (
_RawBatchGetMore,
_RawBatchQuery,
)
from pymongo.synchronous.response import PinnedResponse
from pymongo.synchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.response import PinnedResponse
from pymongo.synchronous.helpers import next
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.write_concern import validate_boolean
if TYPE_CHECKING:
from _typeshed import SupportsItems
from bson.codec_options import CodecOptions
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.read_preferences import _ServerMode
_IS_SYNC = True
@ -179,7 +179,7 @@ class Cursor(Generic[_DocumentType]):
allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use)
if projection is not None:
projection = helpers._fields_list_to_dict(projection, "projection")
projection = helpers_shared._fields_list_to_dict(projection, "projection")
if let is not None:
validate_is_document_type("let", let)
@ -191,7 +191,7 @@ class Cursor(Generic[_DocumentType]):
self._skip = skip
self._limit = limit
self._batch_size = batch_size
self._ordering = sort and helpers._index_document(sort) or None
self._ordering = sort and helpers_shared._index_document(sort) or None
self._max_scan = max_scan
self._explain = False
self._comment = comment
@ -740,8 +740,8 @@ class Cursor(Generic[_DocumentType]):
key, if not given :data:`~pymongo.ASCENDING` is assumed
"""
self._check_okay_to_chain()
keys = helpers._index_list(key_or_list, direction)
self._ordering = helpers._index_document(keys)
keys = helpers_shared._index_list(key_or_list, direction)
self._ordering = helpers_shared._index_document(keys)
return self
def explain(self) -> _DocumentType:
@ -772,7 +772,7 @@ class Cursor(Generic[_DocumentType]):
if isinstance(index, str):
self._hint = index
else:
self._hint = helpers._index_document(index)
self._hint = helpers_shared._index_document(index)
def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]:
"""Adds a 'hint', telling Mongo the proper index to use for the query.
@ -1120,7 +1120,7 @@ class Cursor(Generic[_DocumentType]):
self._address = response.address
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
cmd_name = operation.name
docs = response.docs

View File

@ -33,18 +33,17 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
from bson.dbref import DBRef
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo import _csot, common
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.database_shared import _check_name, _CodecDocumentType
from pymongo.errors import CollectionInvalid, InvalidOperation
from pymongo.synchronous import common
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.synchronous.aggregation import _DatabaseAggregationCommand
from pymongo.synchronous.change_stream import DatabaseChangeStream
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name
from pymongo.synchronous.operations import _Op
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
if TYPE_CHECKING:
import bson
@ -151,7 +150,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
>>> db1.read_preference
Primary()
>>> from pymongo.synchronous.read_preferences import Secondary
>>> from pymongo.read_preferences import Secondary
>>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}]))
>>> db1.read_preference
Primary()

Some files were not shown because too many files have changed in this diff Show More