Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
66935c06cb
@ -58,14 +58,12 @@ functions:
|
||||
|
||||
export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration"
|
||||
export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin"
|
||||
export UPLOAD_BUCKET="${project}"
|
||||
|
||||
cat <<EOT > expansion.yml
|
||||
CURRENT_VERSION: "$CURRENT_VERSION"
|
||||
DRIVERS_TOOLS: "$DRIVERS_TOOLS"
|
||||
MONGO_ORCHESTRATION_HOME: "$MONGO_ORCHESTRATION_HOME"
|
||||
MONGODB_BINARIES: "$MONGODB_BINARIES"
|
||||
UPLOAD_BUCKET: "$UPLOAD_BUCKET"
|
||||
PROJECT_DIRECTORY: "$PROJECT_DIRECTORY"
|
||||
PREPARE_SHELL: |
|
||||
set -o errexit
|
||||
@ -73,7 +71,6 @@ functions:
|
||||
export DRIVERS_TOOLS="$DRIVERS_TOOLS"
|
||||
export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
|
||||
export MONGODB_BINARIES="$MONGODB_BINARIES"
|
||||
export UPLOAD_BUCKET="$UPLOAD_BUCKET"
|
||||
export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
|
||||
|
||||
export TMPDIR="$MONGO_ORCHESTRATION_HOME/db"
|
||||
@ -103,30 +100,35 @@ functions:
|
||||
echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config
|
||||
|
||||
"upload coverage" :
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${assume_role_arn}
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: src/.coverage
|
||||
optional: true
|
||||
# Upload the coverage report for all tasks in a single build to the same directory.
|
||||
remote_file: ${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}
|
||||
bucket: mciuploads
|
||||
remote_file: coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: text/html
|
||||
display_name: "Raw Coverage Report"
|
||||
|
||||
"download and merge coverage" :
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${assume_role_arn}
|
||||
- command: shell.exec
|
||||
params:
|
||||
silent: true
|
||||
working_dir: "src"
|
||||
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
|
||||
script: |
|
||||
export AWS_ACCESS_KEY_ID=${aws_key}
|
||||
export AWS_SECRET_ACCESS_KEY=${aws_secret}
|
||||
|
||||
# Download all the task coverage files.
|
||||
aws s3 cp --recursive s3://mciuploads/${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/coverage/ coverage/
|
||||
aws s3 cp --recursive s3://${bucket_name}/coverage/${revision}/${version_id}/coverage/ coverage/
|
||||
- command: shell.exec
|
||||
params:
|
||||
working_dir: "src"
|
||||
@ -138,24 +140,27 @@ functions:
|
||||
params:
|
||||
silent: true
|
||||
working_dir: "src"
|
||||
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
|
||||
script: |
|
||||
export AWS_ACCESS_KEY_ID=${aws_key}
|
||||
export AWS_SECRET_ACCESS_KEY=${aws_secret}
|
||||
aws s3 cp htmlcov/ s3://mciuploads/${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1
|
||||
aws s3 cp htmlcov/ s3://${bucket_name}/coverage/${revision}/${version_id}/htmlcov/ --recursive --acl public-read --region us-east-1
|
||||
# Attach the index.html with s3.put so it shows up in the Evergreen UI.
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: src/htmlcov/index.html
|
||||
remote_file: ${UPLOAD_BUCKET}/coverage/${revision}/${version_id}/htmlcov/index.html
|
||||
bucket: mciuploads
|
||||
remote_file: coverage/${revision}/${version_id}/htmlcov/index.html
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: text/html
|
||||
display_name: "Coverage Report HTML"
|
||||
|
||||
|
||||
"upload mo artifacts":
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${assume_role_arn}
|
||||
- command: shell.exec
|
||||
params:
|
||||
script: |
|
||||
@ -174,37 +179,43 @@ functions:
|
||||
- "./**.mdmp" # Windows: minidumps
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: mongo-coredumps.tgz
|
||||
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz
|
||||
bucket: mciuploads
|
||||
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|application/gzip}
|
||||
display_name: Core Dumps - Execution
|
||||
optional: true
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: mongodb-logs.tar.gz
|
||||
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz
|
||||
bucket: mciuploads
|
||||
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-mongodb-logs.tar.gz
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|application/x-gzip}
|
||||
display_name: "mongodb-logs.tar.gz"
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: drivers-tools/.evergreen/orchestration/server.log
|
||||
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log
|
||||
bucket: mciuploads
|
||||
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|text/plain}
|
||||
display_name: "orchestration.log"
|
||||
|
||||
"upload working dir":
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${assume_role_arn}
|
||||
- command: archive.targz_pack
|
||||
params:
|
||||
target: "working-dir.tar.gz"
|
||||
@ -213,11 +224,12 @@ functions:
|
||||
- "./**"
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: working-dir.tar.gz
|
||||
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz
|
||||
bucket: mciuploads
|
||||
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|application/x-gzip}
|
||||
display_name: "working-dir.tar.gz"
|
||||
@ -232,11 +244,12 @@ functions:
|
||||
- "*.lock"
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: drivers-dir.tar.gz
|
||||
remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz
|
||||
bucket: mciuploads
|
||||
remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|application/x-gzip}
|
||||
display_name: "drivers-dir.tar.gz"
|
||||
@ -785,6 +798,9 @@ functions:
|
||||
VERSION=${VERSION} ENSURE_UNIVERSAL2=${ENSURE_UNIVERSAL2} .evergreen/release.sh
|
||||
|
||||
"upload release":
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${assume_role_arn}
|
||||
- command: archive.targz_pack
|
||||
params:
|
||||
target: "release-files.tgz"
|
||||
@ -793,25 +809,27 @@ functions:
|
||||
- "*"
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: release-files.tgz
|
||||
remote_file: ${UPLOAD_BUCKET}/release/${revision}/${task_id}-${execution}-release-files.tar.gz
|
||||
bucket: mciuploads
|
||||
remote_file: release/${revision}/${task_id}-${execution}-release-files.tar.gz
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|application/gzip}
|
||||
display_name: Release files
|
||||
|
||||
"download and merge releases":
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${assume_role_arn}
|
||||
- command: shell.exec
|
||||
params:
|
||||
silent: true
|
||||
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
|
||||
script: |
|
||||
export AWS_ACCESS_KEY_ID=${aws_key}
|
||||
export AWS_SECRET_ACCESS_KEY=${aws_secret}
|
||||
|
||||
# Download all the task coverage files.
|
||||
aws s3 cp --recursive s3://mciuploads/${UPLOAD_BUCKET}/release/${revision}/ release/
|
||||
aws s3 cp --recursive s3://${bucket_name}/release/${revision}/ release/
|
||||
- command: shell.exec
|
||||
params:
|
||||
shell: "bash"
|
||||
@ -845,11 +863,12 @@ functions:
|
||||
- "*"
|
||||
- command: s3.put
|
||||
params:
|
||||
aws_key: ${aws_key}
|
||||
aws_secret: ${aws_secret}
|
||||
aws_key: ${AWS_ACCESS_KEY_ID}
|
||||
aws_secret: ${AWS_SECRET_ACCESS_KEY}
|
||||
aws_session_token: ${AWS_SESSION_TOKEN}
|
||||
local_file: release-files-all.tgz
|
||||
remote_file: ${UPLOAD_BUCKET}/release-all/${revision}/${task_id}-${execution}-release-files-all.tar.gz
|
||||
bucket: mciuploads
|
||||
remote_file: release-all/${revision}/${task_id}-${execution}-release-files-all.tar.gz
|
||||
bucket: ${bucket_name}
|
||||
permissions: public-read
|
||||
content_type: ${content_type|application/gzip}
|
||||
display_name: Release files all
|
||||
@ -2108,7 +2127,7 @@ tasks:
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3
|
||||
export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian10/master/latest/libmongocrypt.tar.gz
|
||||
export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/${bucket_name}/libmongocrypt/debian10/master/latest/libmongocrypt.tar.gz
|
||||
SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/tox.sh -m test-eg
|
||||
|
||||
- name: testazurekms-task
|
||||
|
||||
@ -258,6 +258,10 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
|
||||
# Use --capture=tee-sys so pytest prints test output inline:
|
||||
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
|
||||
if [ -z "$TEST_ARGS" ]; then # TODO: remove this in PYTHON-4528
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/synchronous/ $TEST_ARGS
|
||||
fi
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS
|
||||
else
|
||||
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
|
||||
fi
|
||||
|
||||
3
.github/workflows/codeql.yml
vendored
3
.github/workflows/codeql.yml
vendored
@ -26,9 +26,6 @@ jobs:
|
||||
# required for all workflows
|
||||
security-events: write
|
||||
|
||||
# required to fetch internal or private CodeQL packs
|
||||
packages: read
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
6
.github/workflows/dist.yml
vendored
6
.github/workflows/dist.yml
vendored
@ -10,6 +10,10 @@ on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
workflow_call:
|
||||
inputs:
|
||||
ref:
|
||||
required: true
|
||||
type: string
|
||||
|
||||
concurrency:
|
||||
group: dist-${{ github.ref }}
|
||||
@ -44,6 +48,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
@ -99,6 +104,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
11
.github/workflows/release-python.yml
vendored
11
.github/workflows/release-python.yml
vendored
@ -19,7 +19,7 @@ env:
|
||||
PRODUCT_NAME: PyMongo
|
||||
# Changes per branch
|
||||
SILK_ASSET_GROUP: mongodb-python-driver
|
||||
EVERGREEN_PROJECT: mongodb-python-driver
|
||||
EVERGREEN_PROJECT: mongo-python-driver
|
||||
|
||||
defaults:
|
||||
run:
|
||||
@ -32,6 +32,8 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
outputs:
|
||||
version: ${{ steps.pre-publish.outputs.version }}
|
||||
steps:
|
||||
- uses: mongodb-labs/drivers-github-tools/secure-checkout@v2
|
||||
with:
|
||||
@ -44,6 +46,7 @@ jobs:
|
||||
aws_secret_id: ${{ secrets.AWS_SECRET_ID }}
|
||||
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
|
||||
- uses: mongodb-labs/drivers-github-tools/python/pre-publish@v2
|
||||
id: pre-publish
|
||||
with:
|
||||
version: ${{ inputs.version }}
|
||||
dry_run: ${{ inputs.dry_run }}
|
||||
@ -51,12 +54,16 @@ jobs:
|
||||
build-dist:
|
||||
needs: [pre-publish]
|
||||
uses: ./.github/workflows/dist.yml
|
||||
with:
|
||||
ref: ${{ needs.pre-publish.outputs.version }}
|
||||
|
||||
static-scan:
|
||||
needs: [pre-publish]
|
||||
uses: ./.github/workflows/codeql.yml
|
||||
permissions:
|
||||
security-events: write
|
||||
with:
|
||||
ref: ${{ github.ref }}
|
||||
ref: ${{ needs.pre-publish.outputs.version }}
|
||||
|
||||
publish:
|
||||
needs: [build-dist, static-scan]
|
||||
|
||||
5
.github/workflows/test-python.yml
vendored
5
.github/workflows/test-python.yml
vendored
@ -71,6 +71,9 @@ jobs:
|
||||
- name: Run tests
|
||||
run: |
|
||||
tox -m test
|
||||
- name: Run async tests
|
||||
run: |
|
||||
tox -m test-async
|
||||
|
||||
doctest:
|
||||
runs-on: ubuntu-latest
|
||||
@ -203,3 +206,5 @@ jobs:
|
||||
which python
|
||||
pip install -e ".[test]"
|
||||
pytest -v
|
||||
pytest -v test/synchronous/
|
||||
pytest -v test/asynchronous/
|
||||
|
||||
@ -14,7 +14,7 @@ a native Python driver for MongoDB. The `gridfs` package is a
|
||||
[gridfs](https://github.com/mongodb/specifications/blob/master/source/gridfs/gridfs-spec.rst/)
|
||||
implementation on top of `pymongo`.
|
||||
|
||||
PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, and 7.0.
|
||||
PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, 7.0, and 8.0.
|
||||
|
||||
## Support / Feedback
|
||||
|
||||
|
||||
@ -6,7 +6,11 @@ Changes in Version 4.9.0
|
||||
|
||||
PyMongo 4.9 brings a number of improvements including:
|
||||
|
||||
- Added support for MongoDB 8.0.
|
||||
- A new asynchronous API with full asyncio support.
|
||||
- Add support for :attr:`~pymongo.encryption.Algorithm.RANGE` and deprecate
|
||||
:attr:`~pymongo.encryption.Algorithm.RANGEPREVIEW`.
|
||||
- pymongocrypt>=1.10 is now required for :ref:`In-Use Encryption` support.
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
@ -26,12 +30,20 @@ PyMongo 4.8 brings a number of improvements including:
|
||||
|
||||
- The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time.
|
||||
- The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8).
|
||||
- Secure Software Development Life Cycle automation for release process.
|
||||
GitHub Releases now include a Software Bill of Materials, and signature
|
||||
files corresponding to the distribution files released on PyPI.
|
||||
- Fixed a bug in change streams where both ``startAtOperationTime`` and ``resumeToken``
|
||||
could be added to a retry attempt, which caused the retry to fail.
|
||||
- Fallback to stdlib ``ssl`` module when ``pyopenssl`` import fails with AttributeError.
|
||||
- Improved performance of MongoClient operations, especially when many operations are being run concurrently.
|
||||
|
||||
Unavoidable breaking changes
|
||||
............................
|
||||
|
||||
- Since we are now using ``hatch`` as our build backend, we no longer have a ``setup.py`` file
|
||||
and require installation using ``pip``.
|
||||
- Since we are now using ``hatch`` as our build backend, we no longer have a usable ``setup.py`` file
|
||||
and require installation using ``pip``. Attempts to invoke the ``setup.py`` file will raise an exception.
|
||||
Additionally, ``pip`` >= 21.3 is now required for editable installs.
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
@ -42,13 +42,12 @@ from gridfs.grid_file_shared import (
|
||||
_clear_entity_type_registry,
|
||||
)
|
||||
from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.common import validate_string
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.helpers import _check_write_command_response, anext
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.common import validate_string
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
@ -57,11 +56,13 @@ from pymongo.errors import (
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _check_write_command_response
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _disallow_transactions(session: Optional[ClientSession]) -> None:
|
||||
def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
|
||||
if session and session.in_transaction:
|
||||
raise InvalidOperation("GridFS does not support multi-document transactions")
|
||||
|
||||
@ -155,7 +156,7 @@ class AsyncGridFS:
|
||||
await grid_file.write(data)
|
||||
return await grid_file._id
|
||||
|
||||
async def get(self, file_id: Any, session: Optional[ClientSession] = None) -> AsyncGridOut:
|
||||
async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut:
|
||||
"""Get a file from GridFS by ``"_id"``.
|
||||
|
||||
Returns an instance of :class:`~gridfs.grid_file.GridOut`,
|
||||
@ -163,7 +164,7 @@ class AsyncGridFS:
|
||||
|
||||
:param file_id: ``"_id"`` of the file to get
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -178,7 +179,7 @@ class AsyncGridFS:
|
||||
self,
|
||||
filename: Optional[str] = None,
|
||||
version: Optional[int] = -1,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGridOut:
|
||||
"""Get a file from GridFS by ``"filename"`` or metadata fields.
|
||||
@ -205,7 +206,7 @@ class AsyncGridFS:
|
||||
:param version: version of the file to get (defaults
|
||||
to -1, the most recent version uploaded)
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
:param kwargs: find files by custom metadata.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
@ -234,7 +235,10 @@ class AsyncGridFS:
|
||||
raise NoFile("no version %d for filename %r" % (version, filename)) from None
|
||||
|
||||
async def get_last_version(
|
||||
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
|
||||
self,
|
||||
filename: Optional[str] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGridOut:
|
||||
"""Get the most recent version of a file in GridFS by ``"filename"``
|
||||
or metadata fields.
|
||||
@ -244,7 +248,7 @@ class AsyncGridFS:
|
||||
|
||||
:param filename: ``"filename"`` of the file to get, or `None`
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
:param kwargs: find files by custom metadata.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
@ -253,7 +257,7 @@ class AsyncGridFS:
|
||||
return await self.get_version(filename=filename, session=session, **kwargs)
|
||||
|
||||
# TODO add optional safe mode for chunk removal?
|
||||
async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
|
||||
async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None:
|
||||
"""Delete a file from GridFS by ``"_id"``.
|
||||
|
||||
Deletes all data belonging to the file with ``"_id"``:
|
||||
@ -269,7 +273,7 @@ class AsyncGridFS:
|
||||
|
||||
:param file_id: ``"_id"`` of the file to delete
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -281,12 +285,12 @@ class AsyncGridFS:
|
||||
await self._files.delete_one({"_id": file_id}, session=session)
|
||||
await self._chunks.delete_many({"files_id": file_id}, session=session)
|
||||
|
||||
async def list(self, session: Optional[ClientSession] = None) -> list[str]:
|
||||
async def list(self, session: Optional[AsyncClientSession] = None) -> list[str]:
|
||||
"""List the names of all files stored in this instance of
|
||||
:class:`GridFS`.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -306,7 +310,7 @@ class AsyncGridFS:
|
||||
async def find_one(
|
||||
self,
|
||||
filter: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[AsyncGridOut]:
|
||||
@ -327,7 +331,7 @@ class AsyncGridFS:
|
||||
:param args: any additional positional arguments are
|
||||
the same as the arguments to :meth:`find`.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
:param kwargs: any additional keyword arguments
|
||||
are the same as the arguments to :meth:`find`.
|
||||
|
||||
@ -370,7 +374,7 @@ class AsyncGridFS:
|
||||
:meth:`~pymongo.collection.Collection.find`
|
||||
in :class:`~pymongo.collection.Collection`.
|
||||
|
||||
If a :class:`~pymongo.client_session.ClientSession` is passed to
|
||||
If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
|
||||
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
|
||||
are associated with that session.
|
||||
|
||||
@ -406,7 +410,7 @@ class AsyncGridFS:
|
||||
async def exists(
|
||||
self,
|
||||
document_or_id: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Check if a file exists in this instance of :class:`GridFS`.
|
||||
@ -438,7 +442,7 @@ class AsyncGridFS:
|
||||
:param document_or_id: query document, or _id of the
|
||||
document to check for
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
:param kwargs: keyword arguments are used as a
|
||||
query document, if they're present.
|
||||
|
||||
@ -525,7 +529,7 @@ class AsyncGridFSBucket:
|
||||
filename: str,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> AsyncGridIn:
|
||||
"""Opens a Stream that the application can write the contents of the
|
||||
file to.
|
||||
@ -556,7 +560,7 @@ class AsyncGridFSBucket:
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -580,7 +584,7 @@ class AsyncGridFSBucket:
|
||||
filename: str,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> AsyncGridIn:
|
||||
"""Opens a Stream that the application can write the contents of the
|
||||
file to.
|
||||
@ -615,7 +619,7 @@ class AsyncGridFSBucket:
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -641,7 +645,7 @@ class AsyncGridFSBucket:
|
||||
source: Any,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> ObjectId:
|
||||
"""Uploads a user file to a GridFS bucket.
|
||||
|
||||
@ -672,7 +676,7 @@ class AsyncGridFSBucket:
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -692,7 +696,7 @@ class AsyncGridFSBucket:
|
||||
source: Any,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> None:
|
||||
"""Uploads a user file to a GridFS bucket with a custom file id.
|
||||
|
||||
@ -724,7 +728,7 @@ class AsyncGridFSBucket:
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -735,7 +739,7 @@ class AsyncGridFSBucket:
|
||||
await gin.write(source)
|
||||
|
||||
async def open_download_stream(
|
||||
self, file_id: Any, session: Optional[ClientSession] = None
|
||||
self, file_id: Any, session: Optional[AsyncClientSession] = None
|
||||
) -> AsyncGridOut:
|
||||
"""Opens a Stream from which the application can read the contents of
|
||||
the stored file specified by file_id.
|
||||
@ -755,7 +759,7 @@ class AsyncGridFSBucket:
|
||||
|
||||
:param file_id: The _id of the file to be downloaded.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -768,7 +772,7 @@ class AsyncGridFSBucket:
|
||||
|
||||
@_csot.apply
|
||||
async def download_to_stream(
|
||||
self, file_id: Any, destination: Any, session: Optional[ClientSession] = None
|
||||
self, file_id: Any, destination: Any, session: Optional[AsyncClientSession] = None
|
||||
) -> None:
|
||||
"""Downloads the contents of the stored file specified by file_id and
|
||||
writes the contents to `destination`.
|
||||
@ -790,7 +794,7 @@ class AsyncGridFSBucket:
|
||||
:param file_id: The _id of the file to be downloaded.
|
||||
:param destination: a file-like object implementing :meth:`write`.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -803,7 +807,7 @@ class AsyncGridFSBucket:
|
||||
destination.write(chunk)
|
||||
|
||||
@_csot.apply
|
||||
async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
|
||||
async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None:
|
||||
"""Given an file_id, delete this stored file's files collection document
|
||||
and associated chunks from a GridFS bucket.
|
||||
|
||||
@ -819,7 +823,7 @@ class AsyncGridFSBucket:
|
||||
|
||||
:param file_id: The _id of the file to be deleted.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -859,7 +863,7 @@ class AsyncGridFSBucket:
|
||||
:meth:`~pymongo.collection.Collection.find`
|
||||
in :class:`~pymongo.collection.Collection`.
|
||||
|
||||
If a :class:`~pymongo.client_session.ClientSession` is passed to
|
||||
If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
|
||||
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
|
||||
are associated with that session.
|
||||
|
||||
@ -878,7 +882,7 @@ class AsyncGridFSBucket:
|
||||
return AsyncGridOutCursor(self._collection, *args, **kwargs)
|
||||
|
||||
async def open_download_stream_by_name(
|
||||
self, filename: str, revision: int = -1, session: Optional[ClientSession] = None
|
||||
self, filename: str, revision: int = -1, session: Optional[AsyncClientSession] = None
|
||||
) -> AsyncGridOut:
|
||||
"""Opens a Stream from which the application can read the contents of
|
||||
`filename` and optional `revision`.
|
||||
@ -902,7 +906,7 @@ class AsyncGridFSBucket:
|
||||
filename and different uploadDate) of the file to retrieve.
|
||||
Defaults to -1 (the most recent revision).
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
:Note: Revision numbers are defined as follows:
|
||||
|
||||
@ -937,7 +941,7 @@ class AsyncGridFSBucket:
|
||||
filename: str,
|
||||
destination: Any,
|
||||
revision: int = -1,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> None:
|
||||
"""Write the contents of `filename` (with optional `revision`) to
|
||||
`destination`.
|
||||
@ -961,7 +965,7 @@ class AsyncGridFSBucket:
|
||||
filename and different uploadDate) of the file to retrieve.
|
||||
Defaults to -1 (the most recent revision).
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
:Note: Revision numbers are defined as follows:
|
||||
|
||||
@ -985,7 +989,7 @@ class AsyncGridFSBucket:
|
||||
destination.write(chunk)
|
||||
|
||||
async def rename(
|
||||
self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None
|
||||
self, file_id: Any, new_filename: str, session: Optional[AsyncClientSession] = None
|
||||
) -> None:
|
||||
"""Renames the stored file with the specified file_id.
|
||||
|
||||
@ -1002,7 +1006,7 @@ class AsyncGridFSBucket:
|
||||
:param file_id: The _id of the file to be renamed.
|
||||
:param new_filename: The new name of the file.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -1024,7 +1028,7 @@ class AsyncGridIn:
|
||||
def __init__(
|
||||
self,
|
||||
root_collection: AsyncCollection,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Write a file to GridFS
|
||||
@ -1059,7 +1063,7 @@ class AsyncGridIn:
|
||||
|
||||
:param root_collection: root collection to write to
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession` to use for all
|
||||
:class:`~pymongo.client_session.AsyncClientSession` to use for all
|
||||
commands
|
||||
:param kwargs: Any: file level options (see above)
|
||||
|
||||
@ -1402,7 +1406,7 @@ class AsyncGridOut(io.IOBase):
|
||||
root_collection: AsyncCollection,
|
||||
file_id: Optional[int] = None,
|
||||
file_document: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> None:
|
||||
"""Read a file from GridFS
|
||||
|
||||
@ -1420,7 +1424,7 @@ class AsyncGridOut(io.IOBase):
|
||||
:param file_document: file document from
|
||||
`root_collection.files`
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession` to use for all
|
||||
:class:`~pymongo.client_session.AsyncClientSession` to use for all
|
||||
commands
|
||||
|
||||
.. versionchanged:: 3.8
|
||||
@ -1734,7 +1738,7 @@ class _AsyncGridOutChunkIterator:
|
||||
self,
|
||||
grid_out: AsyncGridOut,
|
||||
chunks: AsyncCollection,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
next_chunk: Any,
|
||||
) -> None:
|
||||
self._id = grid_out._id
|
||||
@ -1824,7 +1828,9 @@ class _AsyncGridOutChunkIterator:
|
||||
|
||||
|
||||
class AsyncGridOutIterator:
|
||||
def __init__(self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: ClientSession):
|
||||
def __init__(
|
||||
self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
|
||||
):
|
||||
self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)
|
||||
|
||||
def __aiter__(self) -> AsyncGridOutIterator:
|
||||
@ -1851,7 +1857,7 @@ class AsyncGridOutCursor(AsyncCursor):
|
||||
no_cursor_timeout: bool = False,
|
||||
sort: Optional[Any] = None,
|
||||
batch_size: int = 0,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> None:
|
||||
"""Create a new cursor, similar to the normal
|
||||
:class:`~pymongo.cursor.Cursor`.
|
||||
@ -1894,6 +1900,6 @@ class AsyncGridOutCursor(AsyncCursor):
|
||||
def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
|
||||
raise NotImplementedError("Method does not exist for GridOutCursor")
|
||||
|
||||
def _clone_base(self, session: Optional[ClientSession]) -> AsyncGridOutCursor:
|
||||
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncGridOutCursor:
|
||||
"""Creates an empty GridOutCursor for information to be copied into."""
|
||||
return AsyncGridOutCursor(self._root_collection, session=session)
|
||||
|
||||
@ -5,7 +5,7 @@ import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.errors import InvalidOperation
|
||||
|
||||
_SEEK_SET = os.SEEK_SET
|
||||
|
||||
@ -42,6 +42,7 @@ from gridfs.grid_file_shared import (
|
||||
_grid_out_property,
|
||||
)
|
||||
from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot
|
||||
from pymongo.common import validate_string
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
@ -50,13 +51,13 @@ from pymongo.errors import (
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _check_write_command_response
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.common import validate_string
|
||||
from pymongo.synchronous.cursor import Cursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.helpers import _check_write_command_response, next
|
||||
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.synchronous.helpers import next
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -234,7 +235,10 @@ class GridFS:
|
||||
raise NoFile("no version %d for filename %r" % (version, filename)) from None
|
||||
|
||||
def get_last_version(
|
||||
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
|
||||
self,
|
||||
filename: Optional[str] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> GridOut:
|
||||
"""Get the most recent version of a file in GridFS by ``"filename"``
|
||||
or metadata fields.
|
||||
@ -497,7 +501,7 @@ class GridFSBucket:
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(db, Database):
|
||||
raise TypeError("database must be an instance of AsyncDatabase")
|
||||
raise TypeError("database must be an instance of Database")
|
||||
|
||||
db = _clear_entity_type_registry(db)
|
||||
|
||||
@ -1028,7 +1032,7 @@ class GridIn:
|
||||
provided by :class:`~gridfs.GridFS`.
|
||||
|
||||
Raises :class:`TypeError` if `root_collection` is not an
|
||||
instance of :class:`~pymongo.collection.AsyncCollection`.
|
||||
instance of :class:`~pymongo.collection.Collection`.
|
||||
|
||||
Any of the file level options specified in the `GridFS Spec
|
||||
<http://dochub.mongodb.org/core/gridfsspec>`_ may be passed as
|
||||
@ -1069,10 +1073,10 @@ class GridIn:
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
`root_collection` must use an acknowledged
|
||||
:attr:`~pymongo.collection.AsyncCollection.write_concern`
|
||||
:attr:`~pymongo.collection.Collection.write_concern`
|
||||
"""
|
||||
if not isinstance(root_collection, Collection):
|
||||
raise TypeError("root_collection must be an instance of AsyncCollection")
|
||||
raise TypeError("root_collection must be an instance of Collection")
|
||||
|
||||
if not root_collection.write_concern.acknowledged:
|
||||
raise ConfigurationError("root_collection must use acknowledged write_concern")
|
||||
@ -1401,7 +1405,7 @@ class GridOut(io.IOBase):
|
||||
Either `file_id` or `file_document` must be specified,
|
||||
`file_document` will be given priority if present. Raises
|
||||
:class:`TypeError` if `root_collection` is not an instance of
|
||||
:class:`~pymongo.collection.AsyncCollection`.
|
||||
:class:`~pymongo.collection.Collection`.
|
||||
|
||||
:param root_collection: root collection to read from
|
||||
:param file_id: value of ``"_id"`` for the file to read
|
||||
@ -1424,7 +1428,7 @@ class GridOut(io.IOBase):
|
||||
from the server. Metadata is fetched when first needed.
|
||||
"""
|
||||
if not isinstance(root_collection, Collection):
|
||||
raise TypeError("root_collection must be an instance of AsyncCollection")
|
||||
raise TypeError("root_collection must be an instance of Collection")
|
||||
_disallow_transactions(session)
|
||||
|
||||
root_collection = _clear_entity_type_registry(root_collection)
|
||||
@ -1482,7 +1486,7 @@ class GridOut(io.IOBase):
|
||||
self.open() # type: ignore[unused-coroutine]
|
||||
elif not self._file:
|
||||
raise InvalidOperation(
|
||||
"You must call AsyncGridOut.open() before accessing the %s property" % name
|
||||
"You must call GridOut.open() before accessing the %s property" % name
|
||||
)
|
||||
if name in self._file:
|
||||
return self._file[name]
|
||||
@ -1677,13 +1681,13 @@ class GridOut(io.IOBase):
|
||||
return False
|
||||
|
||||
def __enter__(self) -> GridOut:
|
||||
"""Makes it possible to use :class:`AsyncGridOut` files
|
||||
"""Makes it possible to use :class:`GridOut` files
|
||||
with the async context manager protocol.
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
|
||||
"""Makes it possible to use :class:`AsyncGridOut` files
|
||||
"""Makes it possible to use :class:`GridOut` files
|
||||
with the async context manager protocol.
|
||||
"""
|
||||
self.close()
|
||||
|
||||
@ -89,11 +89,9 @@ TEXT = "text"
|
||||
from pymongo import _csot
|
||||
from pymongo._version import __version__, get_version_string, version_tuple
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.synchronous.collection import ReturnDocument
|
||||
from pymongo.synchronous.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.operations import (
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
IndexModel,
|
||||
@ -102,7 +100,9 @@ from pymongo.synchronous.operations import (
|
||||
UpdateMany,
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.synchronous.read_preferences import ReadPreference
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.collection import ReturnDocument
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
version = __version__
|
||||
|
||||
@ -18,20 +18,20 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Mapping, MutableMapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref
|
||||
from pymongo import common
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_preferences import ReadPreference, _AggWritePref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.asynchronous.typings import _DocumentType, _Pipeline
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -53,7 +53,7 @@ class _AggregationCommand:
|
||||
explicit_session: bool,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
user_fields: Optional[MutableMapping[str, Any]] = None,
|
||||
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
|
||||
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
if "explain" in options:
|
||||
@ -121,7 +121,7 @@ class _AggregationCommand:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_read_preference(
|
||||
self, session: Optional[ClientSession]
|
||||
self, session: Optional[AsyncClientSession]
|
||||
) -> Union[_AggWritePref, _ServerMode]:
|
||||
if self._write_preference:
|
||||
return self._write_preference
|
||||
@ -132,9 +132,9 @@ class _AggregationCommand:
|
||||
|
||||
async def get_cursor(
|
||||
self,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: _ServerMode,
|
||||
) -> AsyncCommandCursor[_DocumentType]:
|
||||
# Serialize command.
|
||||
|
||||
@ -18,17 +18,13 @@ from __future__ import annotations
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
@ -41,17 +37,19 @@ from pymongo.asynchronous.auth_aws import _authenticate_aws
|
||||
from pymongo.asynchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
_OIDCAzureCallback,
|
||||
_OIDCGCPCallback,
|
||||
_OIDCProperties,
|
||||
_OIDCTestCallback,
|
||||
)
|
||||
from pymongo.auth_shared import (
|
||||
MongoCredential,
|
||||
_authenticate_scram_start,
|
||||
_parse_scram_response,
|
||||
_xor,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.hello import Hello
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
@ -69,213 +67,9 @@ except ImportError:
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
MECHANISMS = frozenset(
|
||||
[
|
||||
"GSSAPI",
|
||||
"MONGODB-CR",
|
||||
"MONGODB-OIDC",
|
||||
"MONGODB-X509",
|
||||
"MONGODB-AWS",
|
||||
"PLAIN",
|
||||
"SCRAM-SHA-1",
|
||||
"SCRAM-SHA-256",
|
||||
"DEFAULT",
|
||||
]
|
||||
)
|
||||
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
class _Cache:
|
||||
__slots__ = ("data",)
|
||||
|
||||
_hash_val = hash("_Cache")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = None
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Two instances must always compare equal.
|
||||
if isinstance(other, _Cache):
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, _Cache):
|
||||
return False
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self._hash_val
|
||||
|
||||
|
||||
MongoCredential = namedtuple(
|
||||
"MongoCredential",
|
||||
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
|
||||
)
|
||||
"""A hashable namedtuple of values used for authentication."""
|
||||
|
||||
|
||||
GSSAPIProperties = namedtuple(
|
||||
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
|
||||
)
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(
|
||||
mech: str,
|
||||
source: Optional[str],
|
||||
user: str,
|
||||
passwd: str,
|
||||
extra: Mapping[str, Any],
|
||||
database: Optional[str],
|
||||
) -> MongoCredential:
|
||||
"""Build and return a mechanism specific credentials tuple."""
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
|
||||
raise ConfigurationError(f"{mech} requires a username.")
|
||||
if mech == "GSSAPI":
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for GSSAPI")
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
service_name = properties.get("SERVICE_NAME", "mongodb")
|
||||
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
|
||||
service_realm = properties.get("SERVICE_REALM")
|
||||
props = GSSAPIProperties(
|
||||
service_name=service_name,
|
||||
canonicalize_host_name=canonicalize,
|
||||
service_realm=service_realm,
|
||||
)
|
||||
# Source is always $external.
|
||||
return MongoCredential(mech, "$external", user, passwd, props, None)
|
||||
elif mech == "MONGODB-X509":
|
||||
if passwd is not None:
|
||||
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for MONGODB-X509")
|
||||
# Source is always $external, user can be None.
|
||||
return MongoCredential(mech, "$external", user, None, None, None)
|
||||
elif mech == "MONGODB-AWS":
|
||||
if user is not None and passwd is None:
|
||||
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
|
||||
if source is not None and source != "$external":
|
||||
raise ConfigurationError(
|
||||
"authentication source must be $external or None for MONGODB-AWS"
|
||||
)
|
||||
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
aws_session_token = properties.get("AWS_SESSION_TOKEN")
|
||||
aws_props = _AWSProperties(aws_session_token=aws_session_token)
|
||||
# user can be None for temporary link-local EC2 credentials.
|
||||
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
|
||||
elif mech == "MONGODB-OIDC":
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
callback = properties.get("OIDC_CALLBACK")
|
||||
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
|
||||
environ = properties.get("ENVIRONMENT")
|
||||
token_resource = properties.get("TOKEN_RESOURCE", "")
|
||||
default_allowed = [
|
||||
"*.mongodb.net",
|
||||
"*.mongodb-dev.net",
|
||||
"*.mongodb-qa.net",
|
||||
"*.mongodbgov.net",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
]
|
||||
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
|
||||
msg = (
|
||||
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
|
||||
)
|
||||
if passwd is not None:
|
||||
msg = "password is not supported by MONGODB-OIDC"
|
||||
raise ConfigurationError(msg)
|
||||
if callback or human_callback:
|
||||
if environ is not None:
|
||||
raise ConfigurationError(msg)
|
||||
if callback and human_callback:
|
||||
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
|
||||
raise ConfigurationError(msg)
|
||||
elif environ is not None:
|
||||
if environ == "test":
|
||||
if user is not None:
|
||||
msg = "test environment for MONGODB-OIDC does not support username"
|
||||
raise ConfigurationError(msg)
|
||||
callback = _OIDCTestCallback()
|
||||
elif environ == "azure":
|
||||
passwd = None
|
||||
if not token_resource:
|
||||
raise ConfigurationError(
|
||||
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
|
||||
)
|
||||
callback = _OIDCAzureCallback(token_resource)
|
||||
elif environ == "gcp":
|
||||
passwd = None
|
||||
if not token_resource:
|
||||
raise ConfigurationError(
|
||||
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
|
||||
)
|
||||
callback = _OIDCGCPCallback(token_resource)
|
||||
else:
|
||||
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
|
||||
else:
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
oidc_props = _OIDCProperties(
|
||||
callback=callback,
|
||||
human_callback=human_callback,
|
||||
environment=environ,
|
||||
allowed_hosts=allowed_hosts,
|
||||
token_resource=token_resource,
|
||||
username=user,
|
||||
)
|
||||
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
|
||||
|
||||
elif mech == "PLAIN":
|
||||
source_database = source or database or "$external"
|
||||
return MongoCredential(mech, source_database, user, passwd, None, None)
|
||||
else:
|
||||
source_database = source or database or "admin"
|
||||
if passwd is None:
|
||||
raise ConfigurationError("A password is required.")
|
||||
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
"""XOR two byte strings together."""
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
|
||||
for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_scram_start(
|
||||
credentials: MongoCredential, mechanism: str
|
||||
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
|
||||
username = credentials.username
|
||||
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
|
||||
nonce = standard_b64encode(os.urandom(32))
|
||||
first_bare = b"n=" + user + b",r=" + nonce
|
||||
|
||||
cmd = {
|
||||
"saslStart": 1,
|
||||
"mechanism": mechanism,
|
||||
"payload": Binary(b"n,," + first_bare),
|
||||
"autoAuthorize": 1,
|
||||
"options": {"skipEmptyExchange": True},
|
||||
}
|
||||
return nonce, first_bare, cmd
|
||||
|
||||
|
||||
async def _authenticate_scram(
|
||||
credentials: MongoCredential, conn: Connection, mechanism: str
|
||||
credentials: MongoCredential, conn: AsyncConnection, mechanism: str
|
||||
) -> None:
|
||||
"""Authenticate using SCRAM."""
|
||||
username = credentials.username
|
||||
@ -398,7 +192,7 @@ def _canonicalize_hostname(hostname: str) -> str:
|
||||
return name[0].lower()
|
||||
|
||||
|
||||
async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
|
||||
async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using GSSAPI."""
|
||||
if not HAVE_KERBEROS:
|
||||
raise ConfigurationError(
|
||||
@ -509,7 +303,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -
|
||||
raise OperationFailure(str(exc)) from None
|
||||
|
||||
|
||||
async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
|
||||
async def _authenticate_plain(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using SASL PLAIN (RFC 4616)"""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
@ -524,7 +318,7 @@ async def _authenticate_plain(credentials: MongoCredential, conn: Connection) ->
|
||||
await conn.command(source, cmd)
|
||||
|
||||
|
||||
async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
|
||||
async def _authenticate_x509(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using MONGODB-X509."""
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
@ -535,7 +329,7 @@ async def _authenticate_x509(credentials: MongoCredential, conn: Connection) ->
|
||||
await conn.command("$external", cmd)
|
||||
|
||||
|
||||
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
|
||||
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using MONGODB-CR."""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
@ -550,7 +344,7 @@ async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection)
|
||||
await conn.command(source, query)
|
||||
|
||||
|
||||
async def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
|
||||
async def _authenticate_default(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
if conn.max_wire_version >= 7:
|
||||
if conn.negotiated_mechs:
|
||||
mechs = conn.negotiated_mechs
|
||||
@ -652,7 +446,7 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
|
||||
|
||||
|
||||
async def authenticate(
|
||||
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
|
||||
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool = False
|
||||
) -> None:
|
||||
"""Authenticate connection."""
|
||||
mechanism = credentials.mechanism
|
||||
|
||||
@ -23,13 +23,13 @@ from pymongo.errors import ConfigurationError, OperationFailure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.typings import _ReadableBuffer
|
||||
from pymongo.asynchronous.auth import MongoCredential
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
async def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
async def _authenticate_aws(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using MONGODB-AWS."""
|
||||
try:
|
||||
import pymongo_auth_aws # type:ignore[import]
|
||||
|
||||
@ -15,79 +15,35 @@
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo._azure_helpers import _get_azure_response
|
||||
from pymongo._csot import remaining
|
||||
from pymongo._gcp_helpers import _get_gcp_response
|
||||
from pymongo.auth_oidc_shared import (
|
||||
CALLBACK_VERSION,
|
||||
HUMAN_CALLBACK_TIMEOUT_SECONDS,
|
||||
MACHINE_CALLBACK_TIMEOUT_SECONDS,
|
||||
TIME_BETWEEN_CALLS_SECONDS,
|
||||
OIDCCallback,
|
||||
OIDCCallbackContext,
|
||||
OIDCCallbackResult,
|
||||
OIDCIdPInfo,
|
||||
_OIDCProperties,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
|
||||
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.auth import MongoCredential
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCIdPInfo:
|
||||
issuer: str
|
||||
clientId: Optional[str] = field(default=None)
|
||||
requestScopes: Optional[list[str]] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCCallbackContext:
|
||||
timeout_seconds: float
|
||||
username: str
|
||||
version: int
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCCallbackResult:
|
||||
access_token: str
|
||||
expires_in_seconds: Optional[float] = field(default=None)
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
class OIDCCallback(abc.ABC):
|
||||
"""A base class for defining OIDC callbacks."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
"""Convert the given BSON value into our own type."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCProperties:
|
||||
callback: Optional[OIDCCallback] = field(default=None)
|
||||
human_callback: Optional[OIDCCallback] = field(default=None)
|
||||
environment: Optional[str] = field(default=None)
|
||||
allowed_hosts: list[str] = field(default_factory=list)
|
||||
token_resource: Optional[str] = field(default=None)
|
||||
username: str = ""
|
||||
|
||||
|
||||
"""Mechanism properties for MONGODB-OIDC authentication."""
|
||||
|
||||
TOKEN_BUFFER_MINUTES = 5
|
||||
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
|
||||
CALLBACK_VERSION = 1
|
||||
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
|
||||
TIME_BETWEEN_CALLS_SECONDS = 0.1
|
||||
|
||||
|
||||
def _get_authenticator(
|
||||
credentials: MongoCredential, address: tuple[str, int]
|
||||
) -> _OIDCAuthenticator:
|
||||
@ -117,48 +73,6 @@ def _get_authenticator(
|
||||
return credentials.cache.data
|
||||
|
||||
|
||||
class _OIDCTestCallback(OIDCCallback):
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
token_file = os.environ.get("OIDC_TOKEN_FILE")
|
||||
if not token_file:
|
||||
raise RuntimeError(
|
||||
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
|
||||
)
|
||||
with open(token_file) as fid:
|
||||
return OIDCCallbackResult(access_token=fid.read().strip())
|
||||
|
||||
|
||||
class _OIDCAWSCallback(OIDCCallback):
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
|
||||
if not token_file:
|
||||
raise RuntimeError(
|
||||
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
|
||||
)
|
||||
with open(token_file) as fid:
|
||||
return OIDCCallbackResult(access_token=fid.read().strip())
|
||||
|
||||
|
||||
class _OIDCAzureCallback(OIDCCallback):
|
||||
def __init__(self, token_resource: str) -> None:
|
||||
self.token_resource = quote(token_resource)
|
||||
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
|
||||
return OIDCCallbackResult(
|
||||
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
|
||||
)
|
||||
|
||||
|
||||
class _OIDCGCPCallback(OIDCCallback):
|
||||
def __init__(self, token_resource: str) -> None:
|
||||
self.token_resource = quote(token_resource)
|
||||
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
|
||||
return OIDCCallbackResult(access_token=resp["access_token"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCAuthenticator:
|
||||
username: str
|
||||
@ -170,7 +84,7 @@ class _OIDCAuthenticator:
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
last_call_time: float = field(default=0)
|
||||
|
||||
async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle a reauthenticate from the server."""
|
||||
# Invalidate the token for the connection.
|
||||
self._invalidate(conn)
|
||||
@ -179,7 +93,7 @@ class _OIDCAuthenticator:
|
||||
return await self._authenticate_machine(conn)
|
||||
return await self._authenticate_human(conn)
|
||||
|
||||
async def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
async def authenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle an initial authenticate request."""
|
||||
# First handle speculative auth.
|
||||
# If it succeeded, we are done.
|
||||
@ -203,7 +117,7 @@ class _OIDCAuthenticator:
|
||||
return None
|
||||
return self._get_start_command({"jwt": self.access_token})
|
||||
|
||||
async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
|
||||
async def _authenticate_machine(self, conn: AsyncConnection) -> Mapping[str, Any]:
|
||||
# If there is a cached access token, try to authenticate with it. If
|
||||
# authentication fails with error code 18, invalidate the access token,
|
||||
# fetch a new access token, and try to authenticate again. If authentication
|
||||
@ -217,7 +131,7 @@ class _OIDCAuthenticator:
|
||||
raise
|
||||
return await self._sasl_start_jwt(conn)
|
||||
|
||||
async def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
# If we have a cached access token, try a JwtStepRequest.
|
||||
# authentication fails with error code 18, invalidate the access token,
|
||||
# and try to authenticate again. If authentication fails for any other
|
||||
@ -307,7 +221,7 @@ class _OIDCAuthenticator:
|
||||
return self.access_token
|
||||
|
||||
async def _run_command(
|
||||
self, conn: Connection, cmd: MutableMapping[str, Any]
|
||||
self, conn: AsyncConnection, cmd: MutableMapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
try:
|
||||
return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
@ -321,7 +235,7 @@ class _OIDCAuthenticator:
|
||||
return False
|
||||
return err.code == _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
def _invalidate(self, conn: Connection) -> None:
|
||||
def _invalidate(self, conn: AsyncConnection) -> None:
|
||||
# Ignore the invalidation if a token gen id is given and is less than our
|
||||
# current token gen id.
|
||||
token_gen_id = conn.oidc_token_gen_id or 0
|
||||
@ -330,7 +244,7 @@ class _OIDCAuthenticator:
|
||||
self.access_token = None
|
||||
|
||||
async def _sasl_continue_jwt(
|
||||
self, conn: Connection, start_resp: Mapping[str, Any]
|
||||
self, conn: AsyncConnection, start_resp: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
@ -342,7 +256,7 @@ class _OIDCAuthenticator:
|
||||
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
async def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
|
||||
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
|
||||
access_token = self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_start_command({"jwt": access_token})
|
||||
@ -370,7 +284,7 @@ class _OIDCAuthenticator:
|
||||
|
||||
|
||||
async def _authenticate_oidc(
|
||||
credentials: MongoCredential, conn: Connection, reauthenticate: bool
|
||||
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Authenticate using MONGODB-OIDC."""
|
||||
authenticator = _get_authenticator(credentials, conn.address)
|
||||
|
||||
@ -19,6 +19,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
from collections.abc import MutableMapping
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
@ -26,7 +28,6 @@ from typing import (
|
||||
Any,
|
||||
Iterator,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
@ -34,142 +35,52 @@ from typing import (
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.client_session import ClientSession, _validate_session_write_concern
|
||||
from pymongo.asynchronous.common import (
|
||||
from pymongo import _csot, common
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.bulk_shared import (
|
||||
_COMMANDS,
|
||||
_DELETE_ALL,
|
||||
_merge_command,
|
||||
_raise_bulk_write_error,
|
||||
_Run,
|
||||
)
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo.asynchronous.helpers import _get_wce_doc
|
||||
from pymongo.asynchronous.message import (
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import (
|
||||
_DELETE,
|
||||
_INSERT,
|
||||
_UPDATE,
|
||||
_BulkWriteContext,
|
||||
_convert_exception,
|
||||
_convert_write_result,
|
||||
_EncryptedBulkWriteContext,
|
||||
_randint,
|
||||
)
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
_DELETE_ONE: int = 1
|
||||
|
||||
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
|
||||
_BAD_VALUE: int = 2
|
||||
_UNKNOWN_ERROR: int = 8
|
||||
_WRITE_CONCERN_ERROR: int = 64
|
||||
|
||||
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
|
||||
|
||||
|
||||
class _Run:
|
||||
"""Represents a batch of write operations."""
|
||||
|
||||
def __init__(self, op_type: int) -> None:
|
||||
"""Initialize a new Run object."""
|
||||
self.op_type: int = op_type
|
||||
self.index_map: list[int] = []
|
||||
self.ops: list[Any] = []
|
||||
self.idx_offset: int = 0
|
||||
|
||||
def index(self, idx: int) -> int:
|
||||
"""Get the original index of an operation in this run.
|
||||
|
||||
:param idx: The Run index that maps to the original index.
|
||||
"""
|
||||
return self.index_map[idx]
|
||||
|
||||
def add(self, original_index: int, operation: Any) -> None:
|
||||
"""Add an operation to this Run instance.
|
||||
|
||||
:param original_index: The original index of this operation
|
||||
within a larger bulk operation.
|
||||
:param operation: The operation document.
|
||||
"""
|
||||
self.index_map.append(original_index)
|
||||
self.ops.append(operation)
|
||||
|
||||
|
||||
def _merge_command(
|
||||
run: _Run,
|
||||
full_result: MutableMapping[str, Any],
|
||||
offset: int,
|
||||
result: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""Merge a write command result into the full bulk result."""
|
||||
affected = result.get("n", 0)
|
||||
|
||||
if run.op_type == _INSERT:
|
||||
full_result["nInserted"] += affected
|
||||
|
||||
elif run.op_type == _DELETE:
|
||||
full_result["nRemoved"] += affected
|
||||
|
||||
elif run.op_type == _UPDATE:
|
||||
upserted = result.get("upserted")
|
||||
if upserted:
|
||||
n_upserted = len(upserted)
|
||||
for doc in upserted:
|
||||
doc["index"] = run.index(doc["index"] + offset)
|
||||
full_result["upserted"].extend(upserted)
|
||||
full_result["nUpserted"] += n_upserted
|
||||
full_result["nMatched"] += affected - n_upserted
|
||||
else:
|
||||
full_result["nMatched"] += affected
|
||||
full_result["nModified"] += result["nModified"]
|
||||
|
||||
write_errors = result.get("writeErrors")
|
||||
if write_errors:
|
||||
for doc in write_errors:
|
||||
# Leave the server response intact for APM.
|
||||
replacement = doc.copy()
|
||||
idx = doc["index"] + offset
|
||||
replacement["index"] = run.index(idx)
|
||||
# Add the failed operation to the error document.
|
||||
replacement["op"] = run.ops[idx]
|
||||
full_result["writeErrors"].append(replacement)
|
||||
|
||||
wce = _get_wce_doc(result)
|
||||
if wce:
|
||||
full_result["writeConcernErrors"].append(wce)
|
||||
|
||||
|
||||
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
|
||||
"""Raise a BulkWriteError from the full bulk api result."""
|
||||
# retryWrites on MMAPv1 should raise an actionable error.
|
||||
if full_result["writeErrors"]:
|
||||
full_result["writeErrors"].sort(key=lambda error: error["index"])
|
||||
err = full_result["writeErrors"][0]
|
||||
code = err["code"]
|
||||
msg = err["errmsg"]
|
||||
if code == 20 and msg.startswith("Transaction numbers"):
|
||||
errmsg = (
|
||||
"This MongoDB deployment does not support "
|
||||
"retryable writes. Please add retryWrites=false "
|
||||
"to your connection string."
|
||||
)
|
||||
raise OperationFailure(errmsg, code, full_result)
|
||||
raise BulkWriteError(full_result)
|
||||
|
||||
|
||||
class _Bulk:
|
||||
class _AsyncBulk:
|
||||
"""The private guts of the bulk write API."""
|
||||
|
||||
def __init__(
|
||||
@ -180,7 +91,7 @@ class _Bulk:
|
||||
comment: Optional[str] = None,
|
||||
let: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Initialize a _Bulk instance."""
|
||||
"""Initialize a _AsyncBulk instance."""
|
||||
self.collection = collection.with_options(
|
||||
codec_options=collection.codec_options._replace(
|
||||
unicode_decode_error_handler="replace", document_class=dict
|
||||
@ -204,13 +115,16 @@ class _Bulk:
|
||||
# Extra state so that we know where to pick up on a retry attempt.
|
||||
self.current_run = None
|
||||
self.next_run = None
|
||||
self.is_encrypted = False
|
||||
|
||||
@property
|
||||
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
|
||||
encrypter = self.collection.database.client._encrypter
|
||||
if encrypter and not encrypter._bypass_auto_encryption:
|
||||
self.is_encrypted = True
|
||||
return _EncryptedBulkWriteContext
|
||||
else:
|
||||
self.is_encrypted = False
|
||||
return _BulkWriteContext
|
||||
|
||||
def add_insert(self, document: _DocumentOut) -> None:
|
||||
@ -315,12 +229,236 @@ class _Bulk:
|
||||
if run.ops:
|
||||
yield run
|
||||
|
||||
@_handle_reauth
|
||||
async def write_command(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
cmd[bwc.field] = docs
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
|
||||
if bwc.publish:
|
||||
bwc._fail(request_id, failure, duration)
|
||||
raise
|
||||
finally:
|
||||
bwc.start_time = datetime.datetime.now()
|
||||
return reply # type: ignore[return-value]
|
||||
|
||||
async def unack_write(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
max_doc_size: int,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
cmd = bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if result is not None:
|
||||
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
|
||||
else:
|
||||
# Comply with APM spec.
|
||||
reply = {"ok": 1}
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration)
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, OperationFailure):
|
||||
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
|
||||
elif isinstance(exc, NotPrimaryError):
|
||||
failure = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
if bwc.publish:
|
||||
assert bwc.start_time is not None
|
||||
bwc._fail(request_id, failure, duration)
|
||||
raise
|
||||
finally:
|
||||
bwc.start_time = datetime.datetime.now()
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
async def _execute_batch_unack(
|
||||
self,
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
await bwc.conn.command( # type: ignore[misc]
|
||||
bwc.db_name,
|
||||
batched_cmd, # type: ignore[arg-type]
|
||||
write_concern=WriteConcern(w=0),
|
||||
session=bwc.session, # type: ignore[arg-type]
|
||||
client=client, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
request_id, msg, to_send = bwc.batch_command(cmd, ops)
|
||||
# Though this isn't strictly a "legacy" write, the helper
|
||||
# handles publishing commands and sending our message
|
||||
# without receiving a result. Send 0 for max_doc_size
|
||||
# to disable size checking. Size checking is handled while
|
||||
# the documents are encoded to BSON.
|
||||
await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type]
|
||||
|
||||
return to_send
|
||||
|
||||
async def _execute_batch(
|
||||
self,
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
result = await bwc.conn.command( # type: ignore[misc]
|
||||
bwc.db_name,
|
||||
batched_cmd, # type: ignore[arg-type]
|
||||
codec_options=bwc.codec,
|
||||
session=bwc.session, # type: ignore[arg-type]
|
||||
client=client, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
request_id, msg, to_send = bwc.batch_command(cmd, ops)
|
||||
result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
|
||||
await client._process_response(result, bwc.session) # type: ignore[arg-type]
|
||||
|
||||
return result, to_send # type: ignore[return-value]
|
||||
|
||||
async def _execute_command(
|
||||
self,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[ClientSession],
|
||||
conn: Connection,
|
||||
session: Optional[AsyncClientSession],
|
||||
conn: AsyncConnection,
|
||||
op_id: int,
|
||||
retryable: bool,
|
||||
full_result: MutableMapping[str, Any],
|
||||
@ -335,8 +473,8 @@ class _Bulk:
|
||||
self.next_run = None
|
||||
run = self.current_run
|
||||
|
||||
# Connection.command validates the session, but we use
|
||||
# Connection.write_command
|
||||
# AsyncConnection.command validates the session, but we use
|
||||
# AsyncConnection.write_command
|
||||
conn.validate_session(client, session)
|
||||
last_run = False
|
||||
|
||||
@ -387,7 +525,7 @@ class _Bulk:
|
||||
|
||||
# Run as many ops as possible in one command.
|
||||
if write_concern.acknowledged:
|
||||
result, to_send = await bwc.execute(cmd, ops, client)
|
||||
result, to_send = await self._execute_batch(bwc, cmd, ops, client)
|
||||
|
||||
# Retryable writeConcernErrors halt the execution of this run.
|
||||
wce = result.get("writeConcernError", {})
|
||||
@ -407,7 +545,7 @@ class _Bulk:
|
||||
if self.ordered and "writeErrors" in result:
|
||||
break
|
||||
else:
|
||||
to_send = await bwc.execute_unack(cmd, ops, client)
|
||||
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
|
||||
|
||||
run.idx_offset += len(to_send)
|
||||
|
||||
@ -422,7 +560,7 @@ class _Bulk:
|
||||
self,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute using write commands."""
|
||||
@ -440,7 +578,7 @@ class _Bulk:
|
||||
op_id = _randint()
|
||||
|
||||
async def retryable_bulk(
|
||||
session: Optional[ClientSession], conn: Connection, retryable: bool
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
|
||||
) -> None:
|
||||
await self._execute_command(
|
||||
generator,
|
||||
@ -458,7 +596,7 @@ class _Bulk:
|
||||
retryable_bulk,
|
||||
session,
|
||||
operation,
|
||||
bulk=self,
|
||||
bulk=self, # type: ignore[arg-type]
|
||||
operation_id=op_id,
|
||||
)
|
||||
|
||||
@ -466,7 +604,9 @@ class _Bulk:
|
||||
_raise_bulk_write_error(full_result)
|
||||
return full_result
|
||||
|
||||
async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
|
||||
async def execute_op_msg_no_results(
|
||||
self, conn: AsyncConnection, generator: Iterator[Any]
|
||||
) -> None:
|
||||
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
|
||||
db_name = self.collection.database.name
|
||||
client = self.collection.database.client
|
||||
@ -499,13 +639,13 @@ class _Bulk:
|
||||
conn.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible.
|
||||
to_send = await bwc.execute_unack(cmd, ops, client)
|
||||
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
|
||||
run.idx_offset += len(to_send)
|
||||
self.current_run = run = next(generator, None)
|
||||
|
||||
async def execute_command_no_results(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
@ -541,7 +681,7 @@ class _Bulk:
|
||||
|
||||
async def execute_no_results(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
@ -573,7 +713,7 @@ class _Bulk:
|
||||
async def execute(
|
||||
self,
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> Any:
|
||||
"""Execute operations."""
|
||||
|
||||
@ -21,17 +21,14 @@ from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo import _csot, common
|
||||
from pymongo.asynchronous.aggregation import (
|
||||
_AggregationCommand,
|
||||
_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand,
|
||||
)
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.operations import _Op
|
||||
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
CursorNotFound,
|
||||
@ -39,6 +36,8 @@ from pymongo.errors import (
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -68,11 +67,11 @@ _RESUMABLE_GETMORE_ERRORS = frozenset(
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
|
||||
@ -114,7 +113,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
batch_size: Optional[int],
|
||||
collation: Optional[_CollationIn],
|
||||
start_at_operation_time: Optional[Timestamp],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
start_after: Optional[Mapping[str, Any]],
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
@ -211,7 +210,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
|
||||
def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> None:
|
||||
"""Callback that caches the postBatchResumeToken or
|
||||
startAtOperationTime from a changeStream aggregate command response
|
||||
containing an empty batch of change documents.
|
||||
@ -237,7 +236,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
async def _run_aggregation_cmd(
|
||||
self, session: Optional[ClientSession], explicit_session: bool
|
||||
self, session: Optional[AsyncClientSession], explicit_session: bool
|
||||
) -> AsyncCommandCursor:
|
||||
"""Run the full aggregation pipeline for this ChangeStream and return
|
||||
the corresponding AsyncCommandCursor.
|
||||
|
||||
@ -1,334 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Tools to parse mongo client options."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.compression_support import CompressionSettings
|
||||
from pymongo.asynchronous.monitoring import _EventListener, _EventListeners
|
||||
from pymongo.asynchronous.pool import PoolOptions
|
||||
from pymongo.asynchronous.read_preferences import (
|
||||
_ServerMode,
|
||||
make_read_preference,
|
||||
read_pref_mode_from_name,
|
||||
)
|
||||
from pymongo.asynchronous.server_selectors import any_server_selector
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.asynchronous.auth import MongoCredential
|
||||
from pymongo.asynchronous.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.asynchronous.topology_description import _ServerSelector
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _parse_credentials(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> Optional[MongoCredential]:
|
||||
"""Parse authentication credentials."""
|
||||
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
|
||||
source = options.get("authsource")
|
||||
if username or mechanism:
|
||||
from pymongo.asynchronous.auth import _build_credentials_tuple
|
||||
|
||||
return _build_credentials_tuple(mechanism, source, username, password, options, database)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
|
||||
"""Parse read preference options."""
|
||||
if "read_preference" in options:
|
||||
return options["read_preference"]
|
||||
|
||||
name = options.get("readpreference", "primary")
|
||||
mode = read_pref_mode_from_name(name)
|
||||
tags = options.get("readpreferencetags")
|
||||
max_staleness = options.get("maxstalenessseconds", -1)
|
||||
return make_read_preference(mode, tags, max_staleness)
|
||||
|
||||
|
||||
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
|
||||
"""Parse write concern options."""
|
||||
concern = options.get("w")
|
||||
wtimeout = options.get("wtimeoutms")
|
||||
j = options.get("journal")
|
||||
fsync = options.get("fsync")
|
||||
return WriteConcern(concern, wtimeout, j, fsync)
|
||||
|
||||
|
||||
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
|
||||
"""Parse read concern options."""
|
||||
concern = options.get("readconcernlevel")
|
||||
return ReadConcern(concern)
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
|
||||
"""Parse ssl options."""
|
||||
use_tls = options.get("tls")
|
||||
if use_tls is not None:
|
||||
validate_boolean("tls", use_tls)
|
||||
|
||||
certfile = options.get("tlscertificatekeyfile")
|
||||
passphrase = options.get("tlscertificatekeyfilepassword")
|
||||
ca_certs = options.get("tlscafile")
|
||||
crlfile = options.get("tlscrlfile")
|
||||
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
|
||||
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
|
||||
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
|
||||
|
||||
enabled_tls_opts = []
|
||||
for opt in (
|
||||
"tlscertificatekeyfile",
|
||||
"tlscertificatekeyfilepassword",
|
||||
"tlscafile",
|
||||
"tlscrlfile",
|
||||
):
|
||||
# Any non-null value of these options implies tls=True.
|
||||
if opt in options and options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
for opt in (
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
):
|
||||
# A value of False for these options implies tls=True.
|
||||
if opt in options and not options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
|
||||
if enabled_tls_opts:
|
||||
if use_tls is None:
|
||||
# Implicitly enable TLS when one of the tls* options is set.
|
||||
use_tls = True
|
||||
elif not use_tls:
|
||||
# Error since tls is explicitly disabled but a tls option is set.
|
||||
raise ConfigurationError(
|
||||
"TLS has not been enabled but the "
|
||||
"following tls parameters have been set: "
|
||||
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
|
||||
)
|
||||
|
||||
if use_tls:
|
||||
ctx = get_ssl_context(
|
||||
certfile,
|
||||
passphrase,
|
||||
ca_certs,
|
||||
crlfile,
|
||||
allow_invalid_certificates,
|
||||
allow_invalid_hostnames,
|
||||
disable_ocsp_endpoint_check,
|
||||
)
|
||||
return ctx, allow_invalid_hostnames
|
||||
return None, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> PoolOptions:
|
||||
"""Parse connection pool options."""
|
||||
credentials = _parse_credentials(username, password, database, options)
|
||||
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
|
||||
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
|
||||
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
|
||||
if max_pool_size is not None and min_pool_size > max_pool_size:
|
||||
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
|
||||
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
|
||||
socket_timeout = options.get("sockettimeoutms")
|
||||
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
|
||||
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
|
||||
appname = options.get("appname")
|
||||
driver = options.get("driver")
|
||||
server_api = options.get("server_api")
|
||||
compression_settings = CompressionSettings(
|
||||
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
|
||||
)
|
||||
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
|
||||
load_balanced = options.get("loadbalanced")
|
||||
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
|
||||
return PoolOptions(
|
||||
max_pool_size,
|
||||
min_pool_size,
|
||||
max_idle_time_seconds,
|
||||
connect_timeout,
|
||||
socket_timeout,
|
||||
wait_queue_timeout,
|
||||
ssl_context,
|
||||
tls_allow_invalid_hostnames,
|
||||
_EventListeners(event_listeners),
|
||||
appname,
|
||||
driver,
|
||||
compression_settings,
|
||||
max_connecting=max_connecting,
|
||||
server_api=server_api,
|
||||
load_balanced=load_balanced,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
|
||||
class ClientOptions:
|
||||
"""Read only configuration options for an AsyncMongoClient.
|
||||
|
||||
Should not be instantiated directly by application developers. Access
|
||||
a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options`
|
||||
instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
):
|
||||
self.__options = options
|
||||
self.__codec_options = _parse_codec_options(options)
|
||||
self.__direct_connection = options.get("directconnection")
|
||||
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
|
||||
# self.__server_selection_timeout is in seconds. Must use full name for
|
||||
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
|
||||
self.__server_selection_timeout = options.get(
|
||||
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
|
||||
)
|
||||
self.__pool_options = _parse_pool_options(username, password, database, options)
|
||||
self.__read_preference = _parse_read_preference(options)
|
||||
self.__replica_set_name = options.get("replicaset")
|
||||
self.__write_concern = _parse_write_concern(options)
|
||||
self.__read_concern = _parse_read_concern(options)
|
||||
self.__connect = options.get("connect")
|
||||
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
|
||||
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
|
||||
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
|
||||
self.__server_selector = options.get("server_selector", any_server_selector)
|
||||
self.__auto_encryption_opts = options.get("auto_encryption_opts")
|
||||
self.__load_balanced = options.get("loadbalanced")
|
||||
self.__timeout = options.get("timeoutms")
|
||||
self.__server_monitoring_mode = options.get(
|
||||
"servermonitoringmode", common.SERVER_MONITORING_MODE
|
||||
)
|
||||
|
||||
@property
|
||||
def _options(self) -> Mapping[str, Any]:
|
||||
"""The original options used to create this ClientOptions."""
|
||||
return self.__options
|
||||
|
||||
@property
|
||||
def connect(self) -> Optional[bool]:
|
||||
"""Whether to begin discovering a MongoDB topology automatically."""
|
||||
return self.__connect
|
||||
|
||||
@property
|
||||
def codec_options(self) -> CodecOptions:
|
||||
"""A :class:`~bson.codec_options.CodecOptions` instance."""
|
||||
return self.__codec_options
|
||||
|
||||
@property
|
||||
def direct_connection(self) -> Optional[bool]:
|
||||
"""Whether to connect to the deployment in 'Single' topology."""
|
||||
return self.__direct_connection
|
||||
|
||||
@property
|
||||
def local_threshold_ms(self) -> int:
|
||||
"""The local threshold for this instance."""
|
||||
return self.__local_threshold_ms
|
||||
|
||||
@property
|
||||
def server_selection_timeout(self) -> int:
|
||||
"""The server selection timeout for this instance in seconds."""
|
||||
return self.__server_selection_timeout
|
||||
|
||||
@property
|
||||
def server_selector(self) -> _ServerSelector:
|
||||
return self.__server_selector
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
"""The monitoring frequency in seconds."""
|
||||
return self.__heartbeat_frequency
|
||||
|
||||
@property
|
||||
def pool_options(self) -> PoolOptions:
|
||||
"""A :class:`~pymongo.pool.PoolOptions` instance."""
|
||||
return self.__pool_options
|
||||
|
||||
@property
|
||||
def read_preference(self) -> _ServerMode:
|
||||
"""A read preference instance."""
|
||||
return self.__read_preference
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self.__replica_set_name
|
||||
|
||||
@property
|
||||
def write_concern(self) -> WriteConcern:
|
||||
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
|
||||
return self.__write_concern
|
||||
|
||||
@property
|
||||
def read_concern(self) -> ReadConcern:
|
||||
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
|
||||
return self.__read_concern
|
||||
|
||||
@property
|
||||
def timeout(self) -> Optional[float]:
|
||||
"""The configured timeoutMS converted to seconds, or None.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
return self.__timeout
|
||||
|
||||
@property
|
||||
def retry_writes(self) -> bool:
|
||||
"""If this instance should retry supported write operations."""
|
||||
return self.__retry_writes
|
||||
|
||||
@property
|
||||
def retry_reads(self) -> bool:
|
||||
"""If this instance should retry supported read operations."""
|
||||
return self.__retry_reads
|
||||
|
||||
@property
|
||||
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
|
||||
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
|
||||
return self.__auto_encryption_opts
|
||||
|
||||
@property
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if the client was configured to connect to a load balancer."""
|
||||
return self.__load_balanced
|
||||
|
||||
@property
|
||||
def event_listeners(self) -> list[_EventListeners]:
|
||||
"""The event listeners registered for this client.
|
||||
|
||||
See :mod:`~pymongo.monitoring` for details.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
assert self.__pool_options._event_listeners is not None
|
||||
return self.__pool_options._event_listeners.event_listeners()
|
||||
|
||||
@property
|
||||
def server_monitoring_mode(self) -> str:
|
||||
"""The configured serverMonitoringMode option.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
"""
|
||||
return self.__server_monitoring_mode
|
||||
@ -44,8 +44,8 @@ Transactions
|
||||
.. versionadded:: 3.7
|
||||
|
||||
MongoDB 4.0 adds support for transactions on replica set primaries. A
|
||||
transaction is associated with a :class:`ClientSession`. To start a transaction
|
||||
on a session, use :meth:`ClientSession.start_transaction` in a with-statement.
|
||||
transaction is associated with a :class:`AsyncClientSession`. To start a transaction
|
||||
on a session, use :meth:`AsyncClientSession.start_transaction` in a with-statement.
|
||||
Then, execute an operation within the transaction by passing the session to the
|
||||
operation:
|
||||
|
||||
@ -63,9 +63,9 @@ operation:
|
||||
)
|
||||
|
||||
Upon normal completion of ``async with session.start_transaction()`` block, the
|
||||
transaction automatically calls :meth:`ClientSession.commit_transaction`.
|
||||
transaction automatically calls :meth:`AsyncClientSession.commit_transaction`.
|
||||
If the block exits with an exception, the transaction automatically calls
|
||||
:meth:`ClientSession.abort_transaction`.
|
||||
:meth:`AsyncClientSession.abort_transaction`.
|
||||
|
||||
In general, multi-document transactions only support read/write (CRUD)
|
||||
operations on existing collections. However, MongoDB 4.4 adds support for
|
||||
@ -157,8 +157,6 @@ from bson.int64 import Int64
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.asynchronous.operations import _Op
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
@ -167,23 +165,25 @@ from pymongo.errors import (
|
||||
PyMongoError,
|
||||
WTimeoutError,
|
||||
)
|
||||
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.asynchronous.typings import ClusterTime, _Address
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class SessionOptions:
|
||||
"""Options for a new :class:`ClientSession`.
|
||||
"""Options for a new :class:`AsyncClientSession`.
|
||||
|
||||
:param causal_consistency: If True, read operations are causally
|
||||
ordered within the session. Defaults to True when the ``snapshot``
|
||||
@ -246,7 +246,7 @@ class SessionOptions:
|
||||
|
||||
|
||||
class TransactionOptions:
|
||||
"""Options for :meth:`ClientSession.start_transaction`.
|
||||
"""Options for :meth:`AsyncClientSession.start_transaction`.
|
||||
|
||||
:param read_concern: The
|
||||
:class:`~pymongo.read_concern.ReadConcern` to use for this transaction.
|
||||
@ -336,8 +336,8 @@ class TransactionOptions:
|
||||
|
||||
|
||||
def _validate_session_write_concern(
|
||||
session: Optional[ClientSession], write_concern: Optional[WriteConcern]
|
||||
) -> Optional[ClientSession]:
|
||||
session: Optional[AsyncClientSession], write_concern: Optional[WriteConcern]
|
||||
) -> Optional[AsyncClientSession]:
|
||||
"""Validate that an explicit session is not used with an unack'ed write.
|
||||
|
||||
Returns the session to use for the next operation.
|
||||
@ -362,7 +362,7 @@ def _validate_session_write_concern(
|
||||
class _TransactionContext:
|
||||
"""Internal transaction context manager for start_transaction."""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
def __init__(self, session: AsyncClientSession):
|
||||
self.__session = session
|
||||
|
||||
async def __aenter__(self) -> _TransactionContext:
|
||||
@ -391,7 +391,7 @@ class _TxnState:
|
||||
|
||||
|
||||
class _Transaction:
|
||||
"""Internal class to hold transaction information in a ClientSession."""
|
||||
"""Internal class to hold transaction information in a AsyncClientSession."""
|
||||
|
||||
def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient):
|
||||
self.opts = opts
|
||||
@ -410,12 +410,12 @@ class _Transaction:
|
||||
return self.state == _TxnState.STARTING
|
||||
|
||||
@property
|
||||
def pinned_conn(self) -> Optional[Connection]:
|
||||
def pinned_conn(self) -> Optional[AsyncConnection]:
|
||||
if self.active() and self.conn_mgr:
|
||||
return self.conn_mgr.conn
|
||||
return None
|
||||
|
||||
def pin(self, server: Server, conn: Connection) -> None:
|
||||
def pin(self, server: Server, conn: AsyncConnection) -> None:
|
||||
self.sharded = True
|
||||
self.pinned_address = server.description.address
|
||||
if server.description.server_type == SERVER_TYPE.LoadBalancer:
|
||||
@ -481,16 +481,16 @@ if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
|
||||
|
||||
class ClientSession:
|
||||
class AsyncClientSession:
|
||||
"""A session for ordering sequential operations.
|
||||
|
||||
:class:`ClientSession` instances are **not thread-safe or fork-safe**.
|
||||
:class:`AsyncClientSession` instances are **not thread-safe or fork-safe**.
|
||||
They can only be used by one thread or process at a time. A single
|
||||
:class:`ClientSession` cannot be used to run multiple operations
|
||||
:class:`AsyncClientSession` cannot be used to run multiple operations
|
||||
concurrently.
|
||||
|
||||
Should not be initialized directly by application developers - to create a
|
||||
:class:`ClientSession`, call
|
||||
:class:`AsyncClientSession`, call
|
||||
:meth:`~pymongo.mongo_client.AsyncMongoClient.start_session`.
|
||||
"""
|
||||
|
||||
@ -541,7 +541,7 @@ class ClientSession:
|
||||
if self._server_session is None:
|
||||
raise InvalidOperation("Cannot use ended session")
|
||||
|
||||
async def __aenter__(self) -> ClientSession:
|
||||
async def __aenter__(self) -> AsyncClientSession:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
@ -560,14 +560,14 @@ class ClientSession:
|
||||
return self._options
|
||||
|
||||
@property
|
||||
async def session_id(self) -> Mapping[str, Any]:
|
||||
def session_id(self) -> Mapping[str, Any]:
|
||||
"""A BSON document, the opaque server session identifier."""
|
||||
self._check_ended()
|
||||
self._materialize(self._client.topology_description.logical_session_timeout_minutes)
|
||||
return self._server_session.session_id
|
||||
|
||||
@property
|
||||
async def _transaction_id(self) -> Int64:
|
||||
def _transaction_id(self) -> Int64:
|
||||
"""The current transaction id for the underlying server session."""
|
||||
self._materialize(self._client.topology_description.logical_session_timeout_minutes)
|
||||
return self._server_session.transaction_id
|
||||
@ -598,7 +598,7 @@ class ClientSession:
|
||||
|
||||
async def with_transaction(
|
||||
self,
|
||||
callback: Callable[[ClientSession], _T],
|
||||
callback: Callable[[AsyncClientSession], _T],
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
@ -646,25 +646,25 @@ class ClientSession:
|
||||
however, ``with_transaction`` will return without taking further
|
||||
action.
|
||||
|
||||
:class:`ClientSession` instances are **not thread-safe or fork-safe**.
|
||||
:class:`AsyncClientSession` instances are **not thread-safe or fork-safe**.
|
||||
Consequently, the ``callback`` must not attempt to execute multiple
|
||||
operations concurrently.
|
||||
|
||||
When ``callback`` raises an exception, ``with_transaction``
|
||||
automatically aborts the current transaction. When ``callback`` or
|
||||
:meth:`~ClientSession.commit_transaction` raises an exception that
|
||||
:meth:`~AsyncClientSession.commit_transaction` raises an exception that
|
||||
includes the ``"TransientTransactionError"`` error label,
|
||||
``with_transaction`` starts a new transaction and re-executes
|
||||
the ``callback``.
|
||||
|
||||
When :meth:`~ClientSession.commit_transaction` raises an exception with
|
||||
When :meth:`~AsyncClientSession.commit_transaction` raises an exception with
|
||||
the ``"UnknownTransactionCommitResult"`` error label,
|
||||
``with_transaction`` retries the commit until the result of the
|
||||
transaction is known.
|
||||
|
||||
This method will cease retrying after 120 seconds has elapsed. This
|
||||
timeout is not configurable and any exception raised by the
|
||||
``callback`` or by :meth:`ClientSession.commit_transaction` after the
|
||||
``callback`` or by :meth:`AsyncClientSession.commit_transaction` after the
|
||||
timeout is reached will be re-raised. Applications that desire a
|
||||
different timeout duration should not use this method.
|
||||
|
||||
@ -850,7 +850,7 @@ class ClientSession:
|
||||
"""
|
||||
|
||||
async def func(
|
||||
_session: Optional[ClientSession], conn: Connection, _retryable: bool
|
||||
_session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool
|
||||
) -> dict[str, Any]:
|
||||
return await self._finish_transaction(conn, command_name)
|
||||
|
||||
@ -858,7 +858,7 @@ class ClientSession:
|
||||
func, self, None, retryable=True, operation=_Op.ABORT
|
||||
)
|
||||
|
||||
async def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]:
|
||||
async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]:
|
||||
self._transaction.attempt += 1
|
||||
opts = self._transaction.opts
|
||||
assert opts
|
||||
@ -897,8 +897,8 @@ class ClientSession:
|
||||
"""Update the cluster time for this session.
|
||||
|
||||
:param cluster_time: The
|
||||
:data:`~pymongo.client_session.ClientSession.cluster_time` from
|
||||
another `ClientSession` instance.
|
||||
:data:`~pymongo.client_session.AsyncClientSession.cluster_time` from
|
||||
another `AsyncClientSession` instance.
|
||||
"""
|
||||
if not isinstance(cluster_time, _Mapping):
|
||||
raise TypeError("cluster_time must be a subclass of collections.Mapping")
|
||||
@ -918,8 +918,8 @@ class ClientSession:
|
||||
"""Update the operation time for this session.
|
||||
|
||||
:param operation_time: The
|
||||
:data:`~pymongo.client_session.ClientSession.operation_time` from
|
||||
another `ClientSession` instance.
|
||||
:data:`~pymongo.client_session.AsyncClientSession.operation_time` from
|
||||
another `AsyncClientSession` instance.
|
||||
"""
|
||||
if not isinstance(operation_time, Timestamp):
|
||||
raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp")
|
||||
@ -966,11 +966,11 @@ class ClientSession:
|
||||
return None
|
||||
|
||||
@property
|
||||
def _pinned_connection(self) -> Optional[Connection]:
|
||||
def _pinned_connection(self) -> Optional[AsyncConnection]:
|
||||
"""The connection this transaction was started on."""
|
||||
return self._transaction.pinned_conn
|
||||
|
||||
def _pin(self, server: Server, conn: Connection) -> None:
|
||||
def _pin(self, server: Server, conn: AsyncConnection) -> None:
|
||||
"""Pin this session to the given Server or to the given connection."""
|
||||
self._transaction.pin(server, conn)
|
||||
|
||||
@ -999,7 +999,7 @@ class ClientSession:
|
||||
command: MutableMapping[str, Any],
|
||||
is_retryable: bool,
|
||||
read_preference: _ServerMode,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
) -> None:
|
||||
if not conn.supports_sessions:
|
||||
if not self._implicit:
|
||||
@ -1042,7 +1042,7 @@ class ClientSession:
|
||||
self._check_ended()
|
||||
self._server_session.inc_transaction_id()
|
||||
|
||||
def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None:
|
||||
def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: AsyncConnection) -> None:
|
||||
if self.options.causal_consistency and self.operation_time is not None:
|
||||
cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time
|
||||
if self.options.snapshot:
|
||||
@ -1054,7 +1054,7 @@ class ClientSession:
|
||||
rc["atClusterTime"] = self._snapshot_time
|
||||
|
||||
def __copy__(self) -> NoReturn:
|
||||
raise TypeError("A ClientSession cannot be copied, create a new session instead")
|
||||
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")
|
||||
|
||||
|
||||
class _EmptyServerSession:
|
||||
|
||||
@ -1,226 +0,0 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for working with `collations`_.
|
||||
|
||||
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class CollationStrength:
|
||||
"""
|
||||
An enum that defines values for `strength` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
PRIMARY = 1
|
||||
"""Differentiate base (unadorned) characters."""
|
||||
|
||||
SECONDARY = 2
|
||||
"""Differentiate character accents."""
|
||||
|
||||
TERTIARY = 3
|
||||
"""Differentiate character case."""
|
||||
|
||||
QUATERNARY = 4
|
||||
"""Differentiate words with and without punctuation."""
|
||||
|
||||
IDENTICAL = 5
|
||||
"""Differentiate unicode code point (characters are exactly identical)."""
|
||||
|
||||
|
||||
class CollationAlternate:
|
||||
"""
|
||||
An enum that defines values for `alternate` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
NON_IGNORABLE = "non-ignorable"
|
||||
"""Spaces and punctuation are treated as base characters."""
|
||||
|
||||
SHIFTED = "shifted"
|
||||
"""Spaces and punctuation are *not* considered base characters.
|
||||
|
||||
Spaces and punctuation are distinguished regardless when the
|
||||
:class:`~pymongo.collation.Collation` strength is at least
|
||||
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class CollationMaxVariable:
|
||||
"""
|
||||
An enum that defines values for `max_variable` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
PUNCT = "punct"
|
||||
"""Both punctuation and spaces are ignored."""
|
||||
|
||||
SPACE = "space"
|
||||
"""Spaces alone are ignored."""
|
||||
|
||||
|
||||
class CollationCaseFirst:
|
||||
"""
|
||||
An enum that defines values for `case_first` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
UPPER = "upper"
|
||||
"""Sort uppercase characters first."""
|
||||
|
||||
LOWER = "lower"
|
||||
"""Sort lowercase characters first."""
|
||||
|
||||
OFF = "off"
|
||||
"""Default for locale or collation strength."""
|
||||
|
||||
|
||||
class Collation:
|
||||
"""Collation
|
||||
|
||||
:param locale: (string) The locale of the collation. This should be a string
|
||||
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
|
||||
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
|
||||
documentation for a list of supported locales.
|
||||
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
|
||||
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
|
||||
greater than 2). Defaults to ``False``.
|
||||
:param caseFirst: (optional) Specify that either uppercase or lowercase
|
||||
characters take precedence. Must be one of the following values:
|
||||
|
||||
* :data:`~CollationCaseFirst.UPPER`
|
||||
* :data:`~CollationCaseFirst.LOWER`
|
||||
* :data:`~CollationCaseFirst.OFF` (the default)
|
||||
|
||||
:param strength: Specify the comparison strength. This is also
|
||||
known as the ICU comparison level. This must be one of the following
|
||||
values:
|
||||
|
||||
* :data:`~CollationStrength.PRIMARY`
|
||||
* :data:`~CollationStrength.SECONDARY`
|
||||
* :data:`~CollationStrength.TERTIARY` (the default)
|
||||
* :data:`~CollationStrength.QUATERNARY`
|
||||
* :data:`~CollationStrength.IDENTICAL`
|
||||
|
||||
Each successive level builds upon the previous. For example, a
|
||||
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
|
||||
characters based both on the unadorned base character and its accents.
|
||||
|
||||
:param numericOrdering: If ``True``, order numbers numerically
|
||||
instead of in collation order (defaults to ``False``).
|
||||
:param alternate: Specify whether spaces and punctuation are
|
||||
considered base characters. This must be one of the following values:
|
||||
|
||||
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
|
||||
* :data:`~CollationAlternate.SHIFTED`
|
||||
|
||||
:param maxVariable: When `alternate` is
|
||||
:data:`~CollationAlternate.SHIFTED`, this option specifies what
|
||||
characters may be ignored. This must be one of the following values:
|
||||
|
||||
* :data:`~CollationMaxVariable.PUNCT` (the default)
|
||||
* :data:`~CollationMaxVariable.SPACE`
|
||||
|
||||
:param normalization: If ``True``, normalizes text into Unicode
|
||||
NFD. Defaults to ``False``.
|
||||
:param backwards: If ``True``, accents on characters are
|
||||
considered from the back of the word to the front, as it is done in some
|
||||
French dictionary ordering traditions. Defaults to ``False``.
|
||||
:param kwargs: Keyword arguments supplying any additional options
|
||||
to be sent with this Collation object.
|
||||
|
||||
.. versionadded: 3.4
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
locale: str,
|
||||
caseLevel: Optional[bool] = None,
|
||||
caseFirst: Optional[str] = None,
|
||||
strength: Optional[int] = None,
|
||||
numericOrdering: Optional[bool] = None,
|
||||
alternate: Optional[str] = None,
|
||||
maxVariable: Optional[str] = None,
|
||||
normalization: Optional[bool] = None,
|
||||
backwards: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
locale = common.validate_string("locale", locale)
|
||||
self.__document: dict[str, Any] = {"locale": locale}
|
||||
if caseLevel is not None:
|
||||
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
|
||||
if caseFirst is not None:
|
||||
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
|
||||
if strength is not None:
|
||||
self.__document["strength"] = common.validate_integer("strength", strength)
|
||||
if numericOrdering is not None:
|
||||
self.__document["numericOrdering"] = validate_boolean(
|
||||
"numericOrdering", numericOrdering
|
||||
)
|
||||
if alternate is not None:
|
||||
self.__document["alternate"] = common.validate_string("alternate", alternate)
|
||||
if maxVariable is not None:
|
||||
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
|
||||
if normalization is not None:
|
||||
self.__document["normalization"] = validate_boolean("normalization", normalization)
|
||||
if backwards is not None:
|
||||
self.__document["backwards"] = validate_boolean("backwards", backwards)
|
||||
self.__document.update(kwargs)
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""The document representation of this collation.
|
||||
|
||||
.. note::
|
||||
:class:`Collation` is immutable. Mutating the value of
|
||||
:attr:`document` does not mutate this :class:`Collation`.
|
||||
"""
|
||||
return self.__document.copy()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
document = self.document
|
||||
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Collation):
|
||||
return self.document == other.document
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(
|
||||
value: Optional[Union[Mapping[str, Any], Collation]]
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, Collation):
|
||||
return value.document
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
|
||||
@ -42,27 +42,32 @@ from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import ASCENDING, _csot
|
||||
from pymongo.asynchronous import common, helpers, message
|
||||
from pymongo import ASCENDING, _csot, common, helpers_shared, message
|
||||
from pymongo.asynchronous.aggregation import (
|
||||
_CollectionAggregationCommand,
|
||||
_CollectionRawAggregationCommand,
|
||||
)
|
||||
from pymongo.asynchronous.bulk import _Bulk
|
||||
from pymongo.asynchronous.bulk import _AsyncBulk
|
||||
from pymongo.asynchronous.change_stream import CollectionChangeStream
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.command_cursor import (
|
||||
AsyncCommandCursor,
|
||||
AsyncRawBatchCommandCursor,
|
||||
)
|
||||
from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.asynchronous.cursor import (
|
||||
AsyncCursor,
|
||||
AsyncRawBatchCursor,
|
||||
)
|
||||
from pymongo.asynchronous.helpers import _check_write_command_response
|
||||
from pymongo.asynchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS
|
||||
from pymongo.asynchronous.operations import (
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
InvalidName,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _check_write_command_response
|
||||
from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
IndexModel,
|
||||
@ -75,15 +80,8 @@ from pymongo.asynchronous.operations import (
|
||||
_IndexList,
|
||||
_Op,
|
||||
)
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
InvalidName,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.read_concern import DEFAULT_READ_CONCERN
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import (
|
||||
BulkWriteResult,
|
||||
DeleteResult,
|
||||
@ -91,6 +89,7 @@ from pymongo.results import (
|
||||
InsertOneResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
|
||||
|
||||
_IS_SYNC = False
|
||||
@ -127,11 +126,11 @@ class ReturnDocument:
|
||||
if TYPE_CHECKING:
|
||||
import bson
|
||||
from pymongo.asynchronous.aggregation import _AggregationCommand
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.collation import Collation
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.read_concern import ReadConcern
|
||||
|
||||
|
||||
@ -147,7 +146,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Get / create an asynchronous Mongo collection.
|
||||
@ -373,7 +372,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
def _write_concern_for_cmd(
|
||||
self, cmd: Mapping[str, Any], session: Optional[ClientSession]
|
||||
self, cmd: Mapping[str, Any], session: Optional[AsyncClientSession]
|
||||
) -> WriteConcern:
|
||||
raw_wc = cmd.get("writeConcern")
|
||||
if raw_wc is not None:
|
||||
@ -413,7 +412,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
batch_size: Optional[int] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
start_at_operation_time: Optional[Timestamp] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
start_after: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
@ -494,7 +493,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
the specified :class:`~bson.timestamp.Timestamp`. Requires
|
||||
MongoDB >= 4.0.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param start_after: The same as `resume_after` except that
|
||||
`start_after` can resume notifications after an invalidate event.
|
||||
This option and `resume_after` are mutually exclusive.
|
||||
@ -546,13 +545,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
return change_stream
|
||||
|
||||
async def _conn_for_writes(
|
||||
self, session: Optional[ClientSession], operation: str
|
||||
) -> AsyncContextManager[Connection]:
|
||||
self, session: Optional[AsyncClientSession], operation: str
|
||||
) -> AsyncContextManager[AsyncConnection]:
|
||||
return await self._database.client._conn_for_writes(session, operation)
|
||||
|
||||
async def _command(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
command: MutableMapping[str, Any],
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
@ -561,13 +560,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
retryable_write: bool = False,
|
||||
user_fields: Optional[Any] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""Internal command helper.
|
||||
|
||||
:param conn` - A Connection instance.
|
||||
:param conn` - A AsyncConnection instance.
|
||||
:param command` - The command itself, as a :class:`~bson.son.SON` instance.
|
||||
:param read_preference` (optional) - The read preference to use.
|
||||
:param codec_options` (optional) - An instance of
|
||||
@ -581,7 +580,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
:param collation` (optional) - An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param retryable_write: True if this command is a retryable
|
||||
write.
|
||||
:param user_fields: Response fields that should be decoded
|
||||
@ -613,7 +612,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
name: str,
|
||||
options: MutableMapping[str, Any],
|
||||
collation: Optional[_CollationIn],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
encrypted_fields: Optional[Mapping[str, Any]] = None,
|
||||
qev2_required: bool = False,
|
||||
) -> None:
|
||||
@ -646,7 +645,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _create(
|
||||
self,
|
||||
options: MutableMapping[str, Any],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
) -> None:
|
||||
collation = validate_collation_or_none(options.pop("collation", None))
|
||||
encrypted_fields = options.pop("encryptedFields", None)
|
||||
@ -676,7 +675,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
requests: Sequence[_WriteOp[_DocumentType]],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
) -> BulkWriteResult:
|
||||
@ -726,7 +725,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
write to opt-out of document level validation. Default is
|
||||
``False``.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
@ -755,7 +754,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
common.validate_list("requests", requests)
|
||||
|
||||
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
|
||||
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
|
||||
for request in requests:
|
||||
try:
|
||||
request._add_to_bulk(blk)
|
||||
@ -775,7 +774,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern: WriteConcern,
|
||||
op_id: Optional[int],
|
||||
bypass_doc_val: bool,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
comment: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Internal helper for inserting a single document."""
|
||||
@ -786,7 +785,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
command["comment"] = comment
|
||||
|
||||
async def _insert_command(
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
|
||||
) -> None:
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
@ -815,7 +814,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
document: Union[_DocumentType, RawBSONDocument],
|
||||
bypass_document_validation: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertOneResult:
|
||||
"""Insert a single document.
|
||||
@ -835,7 +834,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
write to opt-out of document level validation. Default is
|
||||
``False``.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -881,7 +880,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertManyResult:
|
||||
"""Insert an iterable of documents.
|
||||
@ -904,7 +903,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
write to opt-out of document level validation. Default is
|
||||
``False``.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -945,14 +944,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
yield (message._INSERT, document)
|
||||
|
||||
write_concern = self._write_concern_for(session)
|
||||
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment)
|
||||
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment)
|
||||
blk.ops = list(gen())
|
||||
await blk.execute(write_concern, session, _Op.INSERT)
|
||||
return InsertManyResult(inserted_ids, write_concern.acknowledged)
|
||||
|
||||
async def _update(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
criteria: Mapping[str, Any],
|
||||
document: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
@ -964,7 +963,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
retryable_write: bool = False,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
@ -996,7 +995,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
hint = helpers_shared._index_document(hint)
|
||||
update_doc["hint"] = hint
|
||||
command = {"update": self.name, "ordered": ordered, "updates": [update_doc]}
|
||||
if let is not None:
|
||||
@ -1051,14 +1050,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Internal update / replace helper."""
|
||||
|
||||
async def _update(
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
return await self._update(
|
||||
conn,
|
||||
@ -1094,7 +1093,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
bypass_document_validation: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> UpdateResult:
|
||||
@ -1143,7 +1142,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -1197,7 +1196,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> UpdateResult:
|
||||
@ -1252,7 +1251,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -1310,7 +1309,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> UpdateResult:
|
||||
@ -1352,7 +1351,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -1404,14 +1403,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def drop(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
encrypted_fields: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Alias for :meth:`~pymongo.database.AsyncDatabase.drop_collection`.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param encrypted_fields: **(BETA)** Document that describes the encrypted fields for
|
||||
@ -1447,7 +1446,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _delete(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
criteria: Mapping[str, Any],
|
||||
multi: bool,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
@ -1455,7 +1454,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
ordered: bool = True,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
retryable_write: bool = False,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
@ -1477,7 +1476,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
hint = helpers_shared._index_document(hint)
|
||||
delete_doc["hint"] = hint
|
||||
command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]}
|
||||
|
||||
@ -1510,14 +1509,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
ordered: bool = True,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""Internal delete helper."""
|
||||
|
||||
async def _delete(
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
|
||||
) -> Mapping[str, Any]:
|
||||
return await self._delete(
|
||||
conn,
|
||||
@ -1546,7 +1545,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> DeleteResult:
|
||||
@ -1570,7 +1569,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -1611,7 +1610,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> DeleteResult:
|
||||
@ -1635,7 +1634,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -1737,7 +1736,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
always be returned. Use a dict to exclude fields from
|
||||
the result (e.g. projection={'_id': False}).
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param skip: the number of documents to omit (from
|
||||
the start of the result set) when returning the results
|
||||
:param limit: the maximum number of results to
|
||||
@ -1931,8 +1930,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _count_cmd(
|
||||
self,
|
||||
session: Optional[ClientSession],
|
||||
conn: Connection,
|
||||
session: Optional[AsyncClientSession],
|
||||
conn: AsyncConnection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: dict[str, Any],
|
||||
collation: Optional[Collation],
|
||||
@ -1956,11 +1955,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _aggregate_one_result(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: dict[str, Any],
|
||||
collation: Optional[_CollationIn],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Internal helper to run an aggregate that returns a single result."""
|
||||
result = await self._command(
|
||||
@ -2012,9 +2011,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
kwargs["comment"] = comment
|
||||
|
||||
async def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> int:
|
||||
cmd: dict[str, Any] = {"count": self._name}
|
||||
@ -2026,7 +2025,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def count_documents(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
@ -2072,7 +2071,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
to count in the collection. Can be an empty document to count all
|
||||
documents.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: See list of options above.
|
||||
@ -2095,14 +2094,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}})
|
||||
cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}}
|
||||
if "hint" in kwargs and not isinstance(kwargs["hint"], str):
|
||||
kwargs["hint"] = helpers._index_document(kwargs["hint"])
|
||||
kwargs["hint"] = helpers_shared._index_document(kwargs["hint"])
|
||||
collation = validate_collation_or_none(kwargs.pop("collation", None))
|
||||
cmd.update(kwargs)
|
||||
|
||||
async def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> int:
|
||||
result = await self._aggregate_one_result(
|
||||
@ -2117,10 +2116,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _retryable_non_cursor_read(
|
||||
self,
|
||||
func: Callable[
|
||||
[Optional[ClientSession], Server, Connection, Optional[_ServerMode]],
|
||||
[Optional[AsyncClientSession], Server, AsyncConnection, Optional[_ServerMode]],
|
||||
Coroutine[Any, Any, T],
|
||||
],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> T:
|
||||
"""Non-cursor read helper to handle implicit session creation."""
|
||||
@ -2131,7 +2130,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def create_indexes(
|
||||
self,
|
||||
indexes: Sequence[IndexModel],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
@ -2147,7 +2146,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
:param indexes: A list of :class:`~pymongo.operations.IndexModel`
|
||||
instances.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the createIndexes
|
||||
@ -2177,14 +2176,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
@_csot.apply
|
||||
async def _create_indexes(
|
||||
self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any
|
||||
self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any
|
||||
) -> list[str]:
|
||||
"""Internal createIndexes helper.
|
||||
|
||||
:param indexes: A list of :class:`~pymongo.operations.IndexModel`
|
||||
instances.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param kwargs: optional arguments to the createIndexes
|
||||
command (like maxTimeMS) can be passed as keyword arguments.
|
||||
"""
|
||||
@ -2223,7 +2222,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def create_index(
|
||||
self,
|
||||
keys: _IndexKeyHint,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -2299,7 +2298,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
:param keys: a single key or a list of (key, direction)
|
||||
pairs specifying the index to create
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: any additional index creation
|
||||
@ -2340,7 +2339,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def drop_indexes(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -2350,7 +2349,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
Raises OperationFailure on an error.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the createIndexes
|
||||
@ -2375,7 +2374,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -2397,7 +2396,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
:param index_or_name: index (or name of index) to drop
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the createIndexes
|
||||
@ -2424,13 +2423,13 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
name = index_or_name
|
||||
if isinstance(index_or_name, list):
|
||||
name = helpers._gen_index_name(index_or_name)
|
||||
name = helpers_shared._gen_index_name(index_or_name)
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("index_or_name must be an instance of str or list")
|
||||
@ -2451,7 +2450,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def list_indexes(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
|
||||
"""Get a cursor over the index documents for this collection.
|
||||
@ -2462,7 +2461,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
SON([('v', 2), ('key', SON([('_id', 1)])), ('name', '_id_')])
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -2480,7 +2479,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _list_indexes(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
|
||||
codec_options: CodecOptions = CodecOptions(SON)
|
||||
@ -2492,9 +2491,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
explicit_session = session is not None
|
||||
|
||||
async def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: _ServerMode,
|
||||
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
|
||||
cmd = {"listIndexes": self._name, "cursor": {}}
|
||||
@ -2529,7 +2528,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def index_information(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Get information on this collection's indexes.
|
||||
@ -2551,7 +2550,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
'x_1': {'unique': True, 'key': [('x', 1)]}}
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -2572,7 +2571,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def list_search_indexes(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCommandCursor[Mapping[str, Any]]:
|
||||
@ -2582,7 +2581,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
for. Only indexes with matching index names will be returned.
|
||||
If not given, all search indexes for the current collection
|
||||
will be returned.
|
||||
:param session: a :class:`~pymongo.client_session.ClientSession`.
|
||||
:param session: a :class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -2625,7 +2624,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def create_search_index(
|
||||
self,
|
||||
model: Union[Mapping[str, Any], SearchIndexModel],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -2636,7 +2635,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
instance or a dictionary with a model "definition" and optional
|
||||
"name".
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the createSearchIndexes
|
||||
@ -2655,14 +2654,14 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def create_search_indexes(
|
||||
self,
|
||||
models: list[SearchIndexModel],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Create multiple search indexes for the current collection.
|
||||
|
||||
:param models: A list of :class:`~pymongo.operations.SearchIndexModel` instances.
|
||||
:param session: a :class:`~pymongo.client_session.ClientSession`.
|
||||
:param session: a :class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the createSearchIndexes
|
||||
@ -2679,7 +2678,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _create_search_indexes(
|
||||
self,
|
||||
models: list[SearchIndexModel],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
@ -2711,7 +2710,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def drop_search_index(
|
||||
self,
|
||||
name: str,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -2719,7 +2718,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
:param name: The name of the search index to be deleted.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the dropSearchIndexes
|
||||
@ -2746,7 +2745,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
name: str,
|
||||
definition: Mapping[str, Any],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -2755,7 +2754,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
:param name: The name of the search index to be updated.
|
||||
:param definition: The new search index definition.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: optional arguments to the updateSearchIndexes
|
||||
@ -2780,7 +2779,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def options(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Get the options set on this collection.
|
||||
@ -2791,7 +2790,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
dictionary if the collection has not been created yet.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -2830,7 +2829,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
aggregation_command: Type[_AggregationCommand],
|
||||
pipeline: _Pipeline,
|
||||
cursor_class: Type[AsyncCommandCursor],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
explicit_session: bool,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
@ -2859,7 +2858,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def aggregate(
|
||||
self,
|
||||
pipeline: _Pipeline,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -2882,7 +2881,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
:param pipeline: a list of aggregation pipeline stages
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: A dict of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -2954,7 +2953,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def aggregate_raw_batches(
|
||||
self,
|
||||
pipeline: _Pipeline,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncRawBatchCursor[_DocumentType]:
|
||||
@ -3003,7 +3002,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def rename(
|
||||
self,
|
||||
new_name: str,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> MutableMapping[str, Any]:
|
||||
@ -3017,7 +3016,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
:param new_name: new name for this collection
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: additional arguments to the rename command
|
||||
@ -3067,7 +3066,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
key: str,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> list:
|
||||
@ -3093,7 +3092,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
:param filter: A query document that specifies the documents
|
||||
from which to retrieve the distinct values.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: See list of options above.
|
||||
@ -3118,9 +3117,9 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd["comment"] = comment
|
||||
|
||||
async def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> list:
|
||||
return (
|
||||
@ -3146,7 +3145,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
return_document: bool = ReturnDocument.BEFORE,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -3163,20 +3162,20 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd["let"] = let
|
||||
cmd.update(kwargs)
|
||||
if projection is not None:
|
||||
cmd["fields"] = helpers._fields_list_to_dict(projection, "projection")
|
||||
cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection")
|
||||
if sort is not None:
|
||||
cmd["sort"] = helpers._index_document(sort)
|
||||
cmd["sort"] = helpers_shared._index_document(sort)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
cmd["upsert"] = upsert
|
||||
if hint is not None:
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
hint = helpers_shared._index_document(hint)
|
||||
|
||||
write_concern = self._write_concern_for_cmd(cmd, session)
|
||||
|
||||
async def _find_and_modify_helper(
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
|
||||
) -> Any:
|
||||
acknowledged = write_concern.acknowledged
|
||||
if array_filters is not None:
|
||||
@ -3222,7 +3221,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
|
||||
sort: Optional[_IndexList] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -3268,7 +3267,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
(e.g. ``[('field', ASCENDING)]``). This option is only supported
|
||||
on MongoDB 4.4 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -3314,7 +3313,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
upsert: bool = False,
|
||||
return_document: bool = ReturnDocument.BEFORE,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -3366,7 +3365,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
@ -3422,7 +3421,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
return_document: bool = ReturnDocument.BEFORE,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -3513,7 +3512,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param let: Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
|
||||
@ -30,22 +30,22 @@ from typing import (
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.asynchronous.message import (
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
_OpReply,
|
||||
_RawBatchGetMore,
|
||||
)
|
||||
from pymongo.asynchronous.response import PinnedResponse
|
||||
from pymongo.asynchronous.typings import _Address, _DocumentOut, _DocumentType
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -62,7 +62,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
address: Optional[_Address],
|
||||
batch_size: int = 0,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
explicit_session: bool = False,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
@ -133,7 +133,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
"""
|
||||
return self._postbatchresumetoken
|
||||
|
||||
async def _maybe_pin_connection(self, conn: Connection) -> None:
|
||||
async def _maybe_pin_connection(self, conn: AsyncConnection) -> None:
|
||||
client = self._collection.database.client
|
||||
if not client._should_pin_cursor(self._session):
|
||||
return
|
||||
@ -188,8 +188,8 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[ClientSession]:
|
||||
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
|
||||
def session(self) -> Optional[AsyncClientSession]:
|
||||
"""The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
@ -272,7 +272,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
|
||||
if response.from_command:
|
||||
cursor = response.docs[0]["cursor"]
|
||||
documents = cursor["nextBatch"]
|
||||
@ -384,7 +384,7 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
|
||||
address: Optional[_Address],
|
||||
batch_size: int = 0,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
explicit_session: bool = False,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,178 +0,0 @@
|
||||
# Copyright 2018 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
|
||||
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
|
||||
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
|
||||
|
||||
|
||||
def _have_snappy() -> bool:
|
||||
try:
|
||||
import snappy # type:ignore[import] # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _have_zlib() -> bool:
|
||||
try:
|
||||
import zlib # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _have_zstd() -> bool:
|
||||
try:
|
||||
import zstandard # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]:
|
||||
try:
|
||||
# `value` is string.
|
||||
compressors = value.split(",") # type: ignore[union-attr]
|
||||
except AttributeError:
|
||||
# `value` is an iterable.
|
||||
compressors = list(value)
|
||||
|
||||
for compressor in compressors[:]:
|
||||
if compressor not in _SUPPORTED_COMPRESSORS:
|
||||
compressors.remove(compressor)
|
||||
warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2)
|
||||
elif compressor == "snappy" and not _have_snappy():
|
||||
compressors.remove(compressor)
|
||||
warnings.warn(
|
||||
"Wire protocol compression with snappy is not available. "
|
||||
"You must install the python-snappy module for snappy support.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif compressor == "zlib" and not _have_zlib():
|
||||
compressors.remove(compressor)
|
||||
warnings.warn(
|
||||
"Wire protocol compression with zlib is not available. "
|
||||
"The zlib module is not available.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif compressor == "zstd" and not _have_zstd():
|
||||
compressors.remove(compressor)
|
||||
warnings.warn(
|
||||
"Wire protocol compression with zstandard is not available. "
|
||||
"You must install the zstandard module for zstandard support.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return compressors
|
||||
|
||||
|
||||
def validate_zlib_compression_level(option: str, value: Any) -> int:
|
||||
try:
|
||||
level = int(value)
|
||||
except Exception:
|
||||
raise TypeError(f"{option} must be an integer, not {value!r}.") from None
|
||||
if level < -1 or level > 9:
|
||||
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
|
||||
return level
|
||||
|
||||
|
||||
class CompressionSettings:
|
||||
def __init__(self, compressors: list[str], zlib_compression_level: int):
|
||||
self.compressors = compressors
|
||||
self.zlib_compression_level = zlib_compression_level
|
||||
|
||||
def get_compression_context(
|
||||
self, compressors: Optional[list[str]]
|
||||
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
|
||||
if compressors:
|
||||
chosen = compressors[0]
|
||||
if chosen == "snappy":
|
||||
return SnappyContext()
|
||||
elif chosen == "zlib":
|
||||
return ZlibContext(self.zlib_compression_level)
|
||||
elif chosen == "zstd":
|
||||
return ZstdContext()
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
class SnappyContext:
|
||||
compressor_id = 1
|
||||
|
||||
@staticmethod
|
||||
def compress(data: bytes) -> bytes:
|
||||
import snappy
|
||||
|
||||
return snappy.compress(data)
|
||||
|
||||
|
||||
class ZlibContext:
|
||||
compressor_id = 2
|
||||
|
||||
def __init__(self, level: int):
|
||||
self.level = level
|
||||
|
||||
def compress(self, data: bytes) -> bytes:
|
||||
import zlib
|
||||
|
||||
return zlib.compress(data, self.level)
|
||||
|
||||
|
||||
class ZstdContext:
|
||||
compressor_id = 3
|
||||
|
||||
@staticmethod
|
||||
def compress(data: bytes) -> bytes:
|
||||
# ZstdCompressor is not thread safe.
|
||||
# TODO: Use a pool?
|
||||
|
||||
import zstandard
|
||||
|
||||
return zstandard.ZstdCompressor().compress(data)
|
||||
|
||||
|
||||
def decompress(data: bytes, compressor_id: int) -> bytes:
|
||||
if compressor_id == SnappyContext.compressor_id:
|
||||
# python-snappy doesn't support the buffer interface.
|
||||
# https://github.com/andrix/python-snappy/issues/65
|
||||
# This only matters when data is a memoryview since
|
||||
# id(bytes(data)) == id(data) when data is a bytes.
|
||||
import snappy
|
||||
|
||||
return snappy.uncompress(bytes(data))
|
||||
elif compressor_id == ZlibContext.compressor_id:
|
||||
import zlib
|
||||
|
||||
return zlib.decompress(data)
|
||||
elif compressor_id == ZstdContext.compressor_id:
|
||||
# ZstdDecompressor is not thread safe.
|
||||
# TODO: Use a pool?
|
||||
import zstandard
|
||||
|
||||
return zstandard.ZstdDecompressor().decompress(data)
|
||||
else:
|
||||
raise ValueError("Unknown compressorId %d" % (compressor_id,))
|
||||
@ -36,14 +36,17 @@ from typing import (
|
||||
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
|
||||
from bson.code import Code
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous import helpers
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.common import (
|
||||
from pymongo import helpers_shared
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_is_mapping,
|
||||
)
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.asynchronous.message import (
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
@ -52,21 +55,18 @@ from pymongo.asynchronous.message import (
|
||||
_RawBatchGetMore,
|
||||
_RawBatchQuery,
|
||||
)
|
||||
from pymongo.asynchronous.response import PinnedResponse
|
||||
from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsItems
|
||||
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -74,8 +74,8 @@ _IS_SYNC = False
|
||||
class _ConnectionManager:
|
||||
"""Used with exhaust cursors to ensure the connection is returned."""
|
||||
|
||||
def __init__(self, conn: Connection, more_to_come: bool):
|
||||
self.conn: Optional[Connection] = conn
|
||||
def __init__(self, conn: AsyncConnection, more_to_come: bool):
|
||||
self.conn: Optional[AsyncConnection] = conn
|
||||
self.more_to_come = more_to_come
|
||||
self._alock = _ALock(_create_lock())
|
||||
|
||||
@ -116,7 +116,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
show_record_id: Optional[bool] = None,
|
||||
snapshot: Optional[bool] = None,
|
||||
comment: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
allow_disk_use: Optional[bool] = None,
|
||||
let: Optional[bool] = None,
|
||||
) -> None:
|
||||
@ -134,7 +134,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
self._exhaust = False
|
||||
self._sock_mgr: Any = None
|
||||
self._killed = False
|
||||
self._session: Optional[ClientSession]
|
||||
self._session: Optional[AsyncClientSession]
|
||||
|
||||
if session:
|
||||
self._session = session
|
||||
@ -179,7 +179,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use)
|
||||
|
||||
if projection is not None:
|
||||
projection = helpers._fields_list_to_dict(projection, "projection")
|
||||
projection = helpers_shared._fields_list_to_dict(projection, "projection")
|
||||
|
||||
if let is not None:
|
||||
validate_is_document_type("let", let)
|
||||
@ -191,7 +191,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
self._skip = skip
|
||||
self._limit = limit
|
||||
self._batch_size = batch_size
|
||||
self._ordering = sort and helpers._index_document(sort) or None
|
||||
self._ordering = sort and helpers_shared._index_document(sort) or None
|
||||
self._max_scan = max_scan
|
||||
self._explain = False
|
||||
self._comment = comment
|
||||
@ -313,7 +313,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
base.__dict__.update(data)
|
||||
return base
|
||||
|
||||
def _clone_base(self, session: Optional[ClientSession]) -> AsyncCursor:
|
||||
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor:
|
||||
"""Creates an empty Cursor object for information to be copied into."""
|
||||
return self.__class__(self._collection, session=session)
|
||||
|
||||
@ -742,8 +742,8 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
key, if not given :data:`~pymongo.ASCENDING` is assumed
|
||||
"""
|
||||
self._check_okay_to_chain()
|
||||
keys = helpers._index_list(key_or_list, direction)
|
||||
self._ordering = helpers._index_document(keys)
|
||||
keys = helpers_shared._index_list(key_or_list, direction)
|
||||
self._ordering = helpers_shared._index_document(keys)
|
||||
return self
|
||||
|
||||
async def explain(self) -> _DocumentType:
|
||||
@ -774,7 +774,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
if isinstance(index, str):
|
||||
self._hint = index
|
||||
else:
|
||||
self._hint = helpers._index_document(index)
|
||||
self._hint = helpers_shared._index_document(index)
|
||||
|
||||
def hint(self, index: Optional[_Hint]) -> AsyncCursor[_DocumentType]:
|
||||
"""Adds a 'hint', telling Mongo the proper index to use for the query.
|
||||
@ -927,8 +927,8 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[ClientSession]:
|
||||
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
|
||||
def session(self) -> Optional[AsyncClientSession]:
|
||||
"""The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
@ -1122,7 +1122,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
self._address = response.address
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
|
||||
|
||||
cmd_name = operation.name
|
||||
docs = response.docs
|
||||
|
||||
@ -33,25 +33,24 @@ from typing import (
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
|
||||
from bson.dbref import DBRef
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo import _csot, common
|
||||
from pymongo.asynchronous.aggregation import _DatabaseAggregationCommand
|
||||
from pymongo.asynchronous.change_stream import DatabaseChangeStream
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.asynchronous.operations import _Op
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.database_shared import _check_name, _CodecDocumentType
|
||||
from pymongo.errors import CollectionInvalid, InvalidOperation
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import bson
|
||||
import bson.codec_options
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
@ -151,7 +150,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
>>> db1.read_preference
|
||||
Primary()
|
||||
>>> from pymongo.asynchronous.read_preferences import Secondary
|
||||
>>> from pymongo.read_preferences import Secondary
|
||||
>>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}]))
|
||||
>>> db1.read_preference
|
||||
Primary()
|
||||
@ -328,7 +327,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
batch_size: Optional[int] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
start_at_operation_time: Optional[Timestamp] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
start_after: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
@ -402,7 +401,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
the specified :class:`~bson.timestamp.Timestamp`. Requires
|
||||
MongoDB >= 4.0.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param start_after: The same as `resume_after` except that
|
||||
`start_after` can resume notifications after an invalidate event.
|
||||
This option and `resume_after` are mutually exclusive.
|
||||
@ -458,7 +457,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
check_exists: Optional[bool] = True,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCollection[_DocumentType]:
|
||||
@ -489,7 +488,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param `check_exists`: if True (the default), send a listCollections command to
|
||||
check if the collection already exists before creation.
|
||||
:param kwargs: additional keyword arguments will
|
||||
@ -607,7 +606,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
return coll
|
||||
|
||||
async def aggregate(
|
||||
self, pipeline: _Pipeline, session: Optional[ClientSession] = None, **kwargs: Any
|
||||
self, pipeline: _Pipeline, session: Optional[AsyncClientSession] = None, **kwargs: Any
|
||||
) -> AsyncCommandCursor[_DocumentType]:
|
||||
"""Perform a database-level aggregation.
|
||||
|
||||
@ -634,7 +633,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
:param pipeline: a list of aggregation pipeline stages
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param kwargs: extra `aggregate command`_ parameters.
|
||||
|
||||
All optional `aggregate command`_ parameters should be passed as
|
||||
@ -688,7 +687,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
@overload
|
||||
async def _command(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -697,7 +696,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
...
|
||||
@ -705,7 +704,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
@overload
|
||||
async def _command(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -714,14 +713,14 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
codec_options: CodecOptions[_CodecDocumentType] = ...,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> _CodecDocumentType:
|
||||
...
|
||||
|
||||
async def _command(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -732,7 +731,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
] = DEFAULT_CODEC_OPTIONS,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[dict[str, Any], _CodecDocumentType]:
|
||||
"""Internal command helper."""
|
||||
@ -763,7 +762,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: None = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
@ -778,7 +777,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: CodecOptions[_CodecDocumentType] = ...,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> _CodecDocumentType:
|
||||
@ -793,7 +792,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[dict[str, Any], _CodecDocumentType]:
|
||||
@ -852,7 +851,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
:param codec_options: A :class:`~bson.codec_options.CodecOptions`
|
||||
instance.
|
||||
:param session: A
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: additional keyword arguments will
|
||||
@ -922,7 +921,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
value: Any = 1,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions[_CodecDocumentType]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
@ -953,7 +952,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
:param codec_options`: A :class:`~bson.codec_options.CodecOptions`
|
||||
instance.
|
||||
:param session: A
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to future getMores for this
|
||||
command.
|
||||
:param max_await_time_ms: The number of ms to wait for more data on future getMores for this command.
|
||||
@ -1024,15 +1023,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
operation: str,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Same as command but used for retryable read commands."""
|
||||
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
|
||||
|
||||
async def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: _ServerMode,
|
||||
) -> dict[str, Any]:
|
||||
return await self._command(
|
||||
@ -1046,8 +1045,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _list_collections(
|
||||
self,
|
||||
conn: Connection,
|
||||
session: Optional[ClientSession],
|
||||
conn: AsyncConnection,
|
||||
session: Optional[AsyncClientSession],
|
||||
read_preference: _ServerMode,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
|
||||
@ -1075,7 +1074,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _list_collections_helper(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -1083,7 +1082,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Get a cursor over the collections of this database.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param filter: A query document to filter the list of
|
||||
collections returned from the listCollections command.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
@ -1106,9 +1105,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
kwargs["comment"] = comment
|
||||
|
||||
async def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: _ServerMode,
|
||||
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
|
||||
return await self._list_collections(
|
||||
@ -1121,7 +1120,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def list_collections(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -1129,7 +1128,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Get a cursor over the collections of this database.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param filter: A query document to filter the list of
|
||||
collections returned from the listCollections command.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
@ -1149,7 +1148,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _list_collection_names(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -1174,7 +1173,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def list_collection_names(
|
||||
self,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
@ -1187,7 +1186,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
db.list_collection_names(filter=filter)
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param filter: A query document to filter the list of
|
||||
collections returned from the listCollections command.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
@ -1207,7 +1206,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
return await self._list_collection_names(session, filter, comment, **kwargs)
|
||||
|
||||
async def _drop_helper(
|
||||
self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None
|
||||
self, name: str, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None
|
||||
) -> dict[str, Any]:
|
||||
command = {"drop": name}
|
||||
if comment is not None:
|
||||
@ -1227,7 +1226,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
async def drop_collection(
|
||||
self,
|
||||
name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
encrypted_fields: Optional[Mapping[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
@ -1236,7 +1235,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
:param name_or_collection: the name of a collection to drop or the
|
||||
collection object itself
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param encrypted_fields: **(BETA)** Document that describes the encrypted fields for
|
||||
@ -1306,7 +1305,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]],
|
||||
scandata: bool = False,
|
||||
full: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
background: Optional[bool] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> dict[str, Any]:
|
||||
@ -1326,7 +1325,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
of the structure of the collection and the individual
|
||||
documents.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param background: A boolean flag that determines whether
|
||||
the command runs in the background. Requires MongoDB 4.4+.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
@ -1386,7 +1385,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
async def dereference(
|
||||
self,
|
||||
dbref: DBRef,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> Optional[_DocumentType]:
|
||||
@ -1401,7 +1400,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
:param dbref: the reference
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: any additional keyword arguments
|
||||
|
||||
@ -59,16 +59,13 @@ from bson.errors import BSONError
|
||||
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.common import CONNECT_TIMEOUT
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.encryption_options import AutoEncryptionOpts, RangeOpts
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.operations import UpdateOne
|
||||
from pymongo.asynchronous.pool import PoolOptions, _configured_socket, _raise_connection_failure
|
||||
from pymongo.asynchronous.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.asynchronous.uri_parser import parse_host
|
||||
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
EncryptedCollectionError,
|
||||
@ -78,9 +75,13 @@ from pymongo.errors import (
|
||||
ServerSelectionTimeoutError,
|
||||
)
|
||||
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
|
||||
from pymongo.operations import UpdateOne
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -381,7 +382,10 @@ class _Encrypter:
|
||||
)
|
||||
|
||||
io_callbacks = _EncryptionIO( # type:ignore[misc]
|
||||
metadata_client, key_vault_coll, mongocryptd_client, opts
|
||||
metadata_client,
|
||||
key_vault_coll, # type:ignore[arg-type]
|
||||
mongocryptd_client,
|
||||
opts,
|
||||
)
|
||||
self._auto_encrypter = AsyncAutoEncrypter(
|
||||
io_callbacks,
|
||||
@ -459,7 +463,14 @@ class Algorithm(str, enum.Enum):
|
||||
RANGE = "Range"
|
||||
"""Range.
|
||||
|
||||
.. versionadded:: 4.8
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
RANGEPREVIEW = "RangePreview"
|
||||
"""**DEPRECATED** - RangePreview.
|
||||
|
||||
.. note:: Support for RangePreview is deprecated. Use :attr:`Algorithm.RANGE` instead.
|
||||
|
||||
.. versionadded:: 4.4
|
||||
"""
|
||||
|
||||
|
||||
@ -473,7 +484,18 @@ class QueryType(str, enum.Enum):
|
||||
"""Used to encrypt a value for an equality query."""
|
||||
|
||||
RANGE = "range"
|
||||
"""Used to encrypt a value for a range query."""
|
||||
"""Used to encrypt a value for a range query.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
|
||||
RANGEPREVIEW = "RangePreview"
|
||||
"""**DEPRECATED** - Used to encrypt a value for a rangePreview query.
|
||||
|
||||
.. note:: Support for RangePreview is deprecated. Use :attr:`QueryType.RANGE` instead.
|
||||
|
||||
.. versionadded:: 4.4
|
||||
"""
|
||||
|
||||
|
||||
def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
|
||||
@ -570,7 +592,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
raise ConfigurationError(
|
||||
"client-side field level encryption requires the pymongocrypt "
|
||||
"library: install a compatible version with: "
|
||||
"python -m pip install 'pymongo[encryption]'"
|
||||
"python -m pip install --upgrade 'pymongo[encryption]'"
|
||||
)
|
||||
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
@ -843,7 +865,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
|
||||
:return: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6.
|
||||
|
||||
.. versionchanged:: 4.8
|
||||
.. versionchanged:: 4.9
|
||||
Added the `range_opts` parameter.
|
||||
|
||||
.. versionchanged:: 4.7
|
||||
@ -899,7 +921,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
|
||||
:return: The encrypted expression, a :class:`~bson.RawBSONDocument`.
|
||||
|
||||
.. versionchanged:: 4.8
|
||||
.. versionchanged:: 4.9
|
||||
Added the `range_opts` parameter.
|
||||
|
||||
.. versionchanged:: 4.7
|
||||
|
||||
@ -1,272 +0,0 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import] # noqa: F401
|
||||
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
except ImportError:
|
||||
_HAVE_PYMONGOCRYPT = False
|
||||
from bson import int64
|
||||
from pymongo.asynchronous.common import validate_is_mapping
|
||||
from pymongo.asynchronous.uri_parser import _parse_kms_tls_options
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.typings import _DocumentTypeArg
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
|
||||
"""Options to configure automatic client-side field level encryption."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: Optional[AsyncMongoClient[_DocumentTypeArg]] = None,
|
||||
schema_map: Optional[Mapping[str, Any]] = None,
|
||||
bypass_auto_encryption: bool = False,
|
||||
mongocryptd_uri: str = "mongodb://localhost:27020",
|
||||
mongocryptd_bypass_spawn: bool = False,
|
||||
mongocryptd_spawn_path: str = "mongocryptd",
|
||||
mongocryptd_spawn_args: Optional[list[str]] = None,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
crypt_shared_lib_path: Optional[str] = None,
|
||||
crypt_shared_lib_required: bool = False,
|
||||
bypass_query_analysis: bool = False,
|
||||
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Options to configure automatic client-side field level encryption.
|
||||
|
||||
Automatic client-side field level encryption requires MongoDB >=4.2
|
||||
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
|
||||
supported for operations on a database or view and will result in
|
||||
error.
|
||||
|
||||
Although automatic encryption requires MongoDB >=4.2 enterprise or a
|
||||
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
|
||||
users. To configure automatic *decryption* without automatic
|
||||
*encryption* set ``bypass_auto_encryption=True``. Explicit
|
||||
encryption and explicit decryption is also supported for all users
|
||||
with the :class:`~pymongo.encryption.ClientEncryption` class.
|
||||
|
||||
See :ref:`automatic-client-side-encryption` for an example.
|
||||
|
||||
:param kms_providers: Map of KMS provider options. The `kms_providers`
|
||||
map values differ by provider:
|
||||
|
||||
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
|
||||
These are the AWS access key ID and AWS secret access key used
|
||||
to generate KMS messages. An optional "sessionToken" may be
|
||||
included to support temporary AWS credentials.
|
||||
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
|
||||
strings. Additionally, "identityPlatformEndpoint" may also be
|
||||
specified as a string (defaults to 'login.microsoftonline.com').
|
||||
These are the Azure Active Directory credentials used to
|
||||
generate Azure Key Vault messages.
|
||||
- `gcp`: Map with "email" as a string and "privateKey"
|
||||
as `bytes` or a base64 encoded string.
|
||||
Additionally, "endpoint" may also be specified as a string
|
||||
(defaults to 'oauth2.googleapis.com'). These are the
|
||||
credentials used to generate Google Cloud KMS messages.
|
||||
- `kmip`: Map with "endpoint" as a host with required port.
|
||||
For example: ``{"endpoint": "example.com:443"}``.
|
||||
- `local`: Map with "key" as `bytes` (96 bytes in length) or
|
||||
a base64 encoded string which decodes
|
||||
to 96 bytes. "key" is the master key used to encrypt/decrypt
|
||||
data keys. This key should be generated and stored as securely
|
||||
as possible.
|
||||
|
||||
KMS providers may be specified with an optional name suffix
|
||||
separated by a colon, for example "kmip:name" or "aws:name".
|
||||
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
|
||||
Named KMS providers enables more than one of each KMS provider type to be configured.
|
||||
For example, to configure multiple local KMS providers::
|
||||
|
||||
kms_providers = {
|
||||
"local": {"key": local_kek1}, # Unnamed KMS provider.
|
||||
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
|
||||
}
|
||||
|
||||
:param key_vault_namespace: The namespace for the key vault collection.
|
||||
The key vault collection contains all data keys used for encryption
|
||||
and decryption. Data keys are stored as documents in this MongoDB
|
||||
collection. Data keys are protected with encryption by a KMS
|
||||
provider.
|
||||
:param key_vault_client: By default, the key vault collection
|
||||
is assumed to reside in the same MongoDB cluster as the encrypted
|
||||
AsyncMongoClient. Use this option to route data key queries to a
|
||||
separate MongoDB cluster.
|
||||
:param schema_map: Map of collection namespace ("db.coll") to
|
||||
JSON Schema. By default, a collection's JSONSchema is periodically
|
||||
polled with the listCollections command. But a JSONSchema may be
|
||||
specified locally with the schemaMap option.
|
||||
|
||||
**Supplying a `schema_map` provides more security than relying on
|
||||
JSON Schemas obtained from the server. It protects against a
|
||||
malicious server advertising a false JSON Schema, which could trick
|
||||
the client into sending unencrypted data that should be
|
||||
encrypted.**
|
||||
|
||||
Schemas supplied in the schemaMap only apply to configuring
|
||||
automatic encryption for client side encryption. Other validation
|
||||
rules in the JSON schema will not be enforced by the driver and
|
||||
will result in an error.
|
||||
:param bypass_auto_encryption: If ``True``, automatic
|
||||
encryption will be disabled but automatic decryption will still be
|
||||
enabled. Defaults to ``False``.
|
||||
:param mongocryptd_uri: The MongoDB URI used to connect
|
||||
to the *local* mongocryptd process. Defaults to
|
||||
``'mongodb://localhost:27020'``.
|
||||
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
|
||||
AsyncMongoClient will not attempt to spawn the mongocryptd process.
|
||||
Defaults to ``False``.
|
||||
:param mongocryptd_spawn_path: Used for spawning the
|
||||
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
|
||||
mongocryptd from the system path.
|
||||
:param mongocryptd_spawn_args: A list of string arguments to
|
||||
use when spawning the mongocryptd process. Defaults to
|
||||
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
|
||||
the ``idleShutdownTimeoutSecs`` option then
|
||||
``'--idleShutdownTimeoutSecs=60'`` will be added.
|
||||
:param kms_tls_options: A map of KMS provider names to TLS
|
||||
options to use when creating secure connections to KMS providers.
|
||||
Accepts the same TLS options as
|
||||
:class:`pymongo.mongo_client.AsyncMongoClient`. For example, to
|
||||
override the system default CA file::
|
||||
|
||||
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
|
||||
|
||||
Or to supply a client certificate::
|
||||
|
||||
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
|
||||
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
|
||||
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
|
||||
unable to load the crypt_shared library.
|
||||
:param bypass_query_analysis: If ``True``, disable automatic analysis
|
||||
of outgoing commands. Set `bypass_query_analysis` to use explicit
|
||||
encryption on indexed fields without the MongoDB Enterprise Advanced
|
||||
licensed crypt_shared library.
|
||||
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
|
||||
that described the encrypted fields for Queryable Encryption. For example::
|
||||
|
||||
{
|
||||
"db.encryptedCollection": {
|
||||
"escCollection": "enxcol_.encryptedCollection.esc",
|
||||
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
|
||||
"fields": [
|
||||
{
|
||||
"path": "firstName",
|
||||
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
|
||||
"bsonType": "string",
|
||||
"queries": {"queryType": "equality"}
|
||||
},
|
||||
{
|
||||
"path": "ssn",
|
||||
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
|
||||
"bsonType": "string"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
|
||||
and `bypass_query_analysis` parameters.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
if not _HAVE_PYMONGOCRYPT:
|
||||
raise ConfigurationError(
|
||||
"client side encryption requires the pymongocrypt library: "
|
||||
"install a compatible version with: "
|
||||
"python -m pip install 'pymongo[encryption]'"
|
||||
)
|
||||
if encrypted_fields_map:
|
||||
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
|
||||
self._encrypted_fields_map = encrypted_fields_map
|
||||
self._bypass_query_analysis = bypass_query_analysis
|
||||
self._crypt_shared_lib_path = crypt_shared_lib_path
|
||||
self._crypt_shared_lib_required = crypt_shared_lib_required
|
||||
self._kms_providers = kms_providers
|
||||
self._key_vault_namespace = key_vault_namespace
|
||||
self._key_vault_client = key_vault_client
|
||||
self._schema_map = schema_map
|
||||
self._bypass_auto_encryption = bypass_auto_encryption
|
||||
self._mongocryptd_uri = mongocryptd_uri
|
||||
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
|
||||
self._mongocryptd_spawn_path = mongocryptd_spawn_path
|
||||
if mongocryptd_spawn_args is None:
|
||||
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
|
||||
self._mongocryptd_spawn_args = mongocryptd_spawn_args
|
||||
if not isinstance(self._mongocryptd_spawn_args, list):
|
||||
raise TypeError("mongocryptd_spawn_args must be a list")
|
||||
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
|
||||
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
|
||||
# Maps KMS provider name to a SSLContext.
|
||||
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
|
||||
self._bypass_query_analysis = bypass_query_analysis
|
||||
|
||||
|
||||
class RangeOpts:
|
||||
"""Options to configure encrypted queries using the range algorithm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sparsity: int,
|
||||
trim_factor: int,
|
||||
min: Optional[Any] = None,
|
||||
max: Optional[Any] = None,
|
||||
precision: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Options to configure encrypted queries using the range algorithm.
|
||||
|
||||
:param sparsity: An integer.
|
||||
:param trim_factor: An integer.
|
||||
:param min: A BSON scalar value corresponding to the type being queried.
|
||||
:param max: A BSON scalar value corresponding to the type being queried.
|
||||
:param precision: An integer, may only be set for double or decimal128 types.
|
||||
|
||||
.. versionadded:: 4.4
|
||||
"""
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.sparsity = sparsity
|
||||
self.trim_factor = trim_factor
|
||||
self.precision = precision
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
doc = {}
|
||||
for k, v in [
|
||||
("sparsity", int64.Int64(self.sparsity)),
|
||||
("trimFactor", self.trim_factor),
|
||||
("precision", self.precision),
|
||||
("min", self.min),
|
||||
("max", self.max),
|
||||
]:
|
||||
if v is not None:
|
||||
doc[k] = v
|
||||
return doc
|
||||
@ -1,225 +0,0 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Example event logger classes.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
These loggers can be registered using :func:`register` or
|
||||
:class:`~pymongo.mongo_client.MongoClient`.
|
||||
|
||||
``monitoring.register(CommandLogger())``
|
||||
|
||||
or
|
||||
|
||||
``MongoClient(event_listeners=[CommandLogger()])``
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from pymongo.asynchronous import monitoring
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class CommandLogger(monitoring.CommandListener):
|
||||
"""A simple listener that logs command events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
|
||||
:class:`~pymongo.monitoring.CommandSucceededEvent` and
|
||||
:class:`~pymongo.monitoring.CommandFailedEvent` events and
|
||||
logs them at the `INFO` severity level using :mod:`logging`.
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def started(self, event: monitoring.CommandStartedEvent) -> None:
|
||||
logging.info(
|
||||
f"Command {event.command_name} with request id "
|
||||
f"{event.request_id} started on server "
|
||||
f"{event.connection_id}"
|
||||
)
|
||||
|
||||
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
|
||||
logging.info(
|
||||
f"Command {event.command_name} with request id "
|
||||
f"{event.request_id} on server {event.connection_id} "
|
||||
f"succeeded in {event.duration_micros} "
|
||||
"microseconds"
|
||||
)
|
||||
|
||||
def failed(self, event: monitoring.CommandFailedEvent) -> None:
|
||||
logging.info(
|
||||
f"Command {event.command_name} with request id "
|
||||
f"{event.request_id} on server {event.connection_id} "
|
||||
f"failed in {event.duration_micros} "
|
||||
"microseconds"
|
||||
)
|
||||
|
||||
|
||||
class ServerLogger(monitoring.ServerListener):
|
||||
"""A simple listener that logs server discovery events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
|
||||
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
|
||||
and :class:`~pymongo.monitoring.ServerClosedEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
|
||||
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
|
||||
|
||||
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
|
||||
previous_server_type = event.previous_description.server_type
|
||||
new_server_type = event.new_description.server_type
|
||||
if new_server_type != previous_server_type:
|
||||
# server_type_name was added in PyMongo 3.4
|
||||
logging.info(
|
||||
f"Server {event.server_address} changed type from "
|
||||
f"{event.previous_description.server_type_name} to "
|
||||
f"{event.new_description.server_type_name}"
|
||||
)
|
||||
|
||||
def closed(self, event: monitoring.ServerClosedEvent) -> None:
|
||||
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
|
||||
|
||||
|
||||
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
|
||||
"""A simple listener that logs server heartbeat events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
|
||||
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
|
||||
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
|
||||
logging.info(f"Heartbeat sent to server {event.connection_id}")
|
||||
|
||||
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
|
||||
# The reply.document attribute was added in PyMongo 3.4.
|
||||
logging.info(
|
||||
f"Heartbeat to server {event.connection_id} "
|
||||
"succeeded with reply "
|
||||
f"{event.reply.document}"
|
||||
)
|
||||
|
||||
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
|
||||
logging.warning(
|
||||
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
|
||||
)
|
||||
|
||||
|
||||
class TopologyLogger(monitoring.TopologyListener):
|
||||
"""A simple listener that logs server topology events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
|
||||
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
|
||||
and :class:`~pymongo.monitoring.TopologyClosedEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
|
||||
logging.info(f"Topology with id {event.topology_id} opened")
|
||||
|
||||
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
|
||||
logging.info(f"Topology description updated for topology id {event.topology_id}")
|
||||
previous_topology_type = event.previous_description.topology_type
|
||||
new_topology_type = event.new_description.topology_type
|
||||
if new_topology_type != previous_topology_type:
|
||||
# topology_type_name was added in PyMongo 3.4
|
||||
logging.info(
|
||||
f"Topology {event.topology_id} changed type from "
|
||||
f"{event.previous_description.topology_type_name} to "
|
||||
f"{event.new_description.topology_type_name}"
|
||||
)
|
||||
# The has_writable_server and has_readable_server methods
|
||||
# were added in PyMongo 3.4.
|
||||
if not event.new_description.has_writable_server():
|
||||
logging.warning("No writable servers available.")
|
||||
if not event.new_description.has_readable_server():
|
||||
logging.warning("No readable servers available.")
|
||||
|
||||
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
|
||||
logging.info(f"Topology with id {event.topology_id} closed")
|
||||
|
||||
|
||||
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
|
||||
"""A simple listener that logs server connection pool events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
|
||||
:class:`~pymongo.monitoring.PoolClearedEvent`,
|
||||
:class:`~pymongo.monitoring.PoolClosedEvent`,
|
||||
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
|
||||
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool created")
|
||||
|
||||
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool ready")
|
||||
|
||||
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool cleared")
|
||||
|
||||
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool closed")
|
||||
|
||||
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
|
||||
|
||||
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
|
||||
)
|
||||
|
||||
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] "
|
||||
f'connection closed, reason: "{event.reason}"'
|
||||
)
|
||||
|
||||
def connection_check_out_started(
|
||||
self, event: monitoring.ConnectionCheckOutStartedEvent
|
||||
) -> None:
|
||||
logging.info(f"[pool {event.address}] connection check out started")
|
||||
|
||||
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
|
||||
|
||||
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
|
||||
)
|
||||
|
||||
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
|
||||
)
|
||||
@ -1,26 +0,0 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The HelloCompat class, placed here to break circular import issues."""
|
||||
from __future__ import annotations
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class HelloCompat:
|
||||
CMD = "hello"
|
||||
LEGACY_CMD = "ismaster"
|
||||
PRIMARY = "isWritablePrimary"
|
||||
LEGACY_PRIMARY = "ismaster"
|
||||
LEGACY_ERROR = "not master"
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,270 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Bits and pieces used by the driver that don't really fit elsewhere."""
|
||||
"""Miscellaneous pieces that need to be synchronized."""
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
import traceback
|
||||
from collections import abc
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Iterable,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.errors import (
|
||||
CursorNotFound,
|
||||
DuplicateKeyError,
|
||||
ExecutionTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
WriteConcernError,
|
||||
WriteError,
|
||||
WTimeoutError,
|
||||
_wtimeout_error,
|
||||
)
|
||||
from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.operations import _IndexList
|
||||
from pymongo.asynchronous.typings import _DocumentOut
|
||||
from pymongo.cursor_shared import _Hint
|
||||
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _gen_index_name(keys: _IndexList) -> str:
|
||||
"""Generate an index name from the set of fields it is over."""
|
||||
return "_".join(["{}_{}".format(*item) for item in keys])
|
||||
|
||||
|
||||
def _index_list(
|
||||
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
|
||||
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
|
||||
"""Helper to generate a list of (key, direction) pairs.
|
||||
|
||||
Takes such a list, or a single key, or a single key and direction.
|
||||
"""
|
||||
if direction is not None:
|
||||
if not isinstance(key_or_list, str):
|
||||
raise TypeError("Expected a string and a direction")
|
||||
return [(key_or_list, direction)]
|
||||
else:
|
||||
if isinstance(key_or_list, str):
|
||||
return [(key_or_list, ASCENDING)]
|
||||
elif isinstance(key_or_list, abc.ItemsView):
|
||||
return list(key_or_list) # type: ignore[arg-type]
|
||||
elif isinstance(key_or_list, abc.Mapping):
|
||||
return list(key_or_list.items())
|
||||
elif not isinstance(key_or_list, (list, tuple)):
|
||||
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
||||
values: list[tuple[str, int]] = []
|
||||
for item in key_or_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
values.append(item)
|
||||
return values
|
||||
|
||||
|
||||
def _index_document(index_list: _IndexList) -> dict[str, Any]:
|
||||
"""Helper to generate an index specifying document.
|
||||
|
||||
Takes a list of (key, direction) pairs.
|
||||
"""
|
||||
if not isinstance(index_list, (list, tuple, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
|
||||
)
|
||||
if not len(index_list):
|
||||
raise ValueError("key_or_list must not be empty")
|
||||
|
||||
index: dict[str, Any] = {}
|
||||
|
||||
if isinstance(index_list, abc.Mapping):
|
||||
for key in index_list:
|
||||
value = index_list[key]
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
else:
|
||||
for item in index_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
key, value = item
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
return index
|
||||
|
||||
|
||||
def _validate_index_key_pair(key: Any, value: Any) -> None:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("first item in each key pair must be an instance of str")
|
||||
if not isinstance(value, (str, int, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"second item in each key pair must be 1, -1, "
|
||||
"'2d', or another valid MongoDB index specifier."
|
||||
)
|
||||
|
||||
|
||||
def _check_command_response(
|
||||
response: _DocumentOut,
|
||||
max_wire_version: Optional[int],
|
||||
allowable_errors: Optional[Container[Union[int, str]]] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
) -> None:
|
||||
"""Check the response to a command for errors."""
|
||||
if "ok" not in response:
|
||||
# Server didn't recognize our message as a command.
|
||||
raise OperationFailure(
|
||||
response.get("$err"), # type: ignore[arg-type]
|
||||
response.get("code"),
|
||||
response,
|
||||
max_wire_version,
|
||||
)
|
||||
|
||||
if parse_write_concern_error and "writeConcernError" in response:
|
||||
_error = response["writeConcernError"]
|
||||
_labels = response.get("errorLabels")
|
||||
if _labels:
|
||||
_error.update({"errorLabels": _labels})
|
||||
_raise_write_concern_error(_error)
|
||||
|
||||
if response["ok"]:
|
||||
return
|
||||
|
||||
details = response
|
||||
# Mongos returns the error details in a 'raw' object
|
||||
# for some errors.
|
||||
if "raw" in response:
|
||||
for shard in response["raw"].values():
|
||||
# Grab the first non-empty raw error from a shard.
|
||||
if shard.get("errmsg") and not shard.get("ok"):
|
||||
details = shard
|
||||
break
|
||||
|
||||
errmsg = details["errmsg"]
|
||||
code = details.get("code")
|
||||
|
||||
# For allowable errors, only check for error messages when the code is not
|
||||
# included.
|
||||
if allowable_errors:
|
||||
if code is not None:
|
||||
if code in allowable_errors:
|
||||
return
|
||||
elif errmsg in allowable_errors:
|
||||
return
|
||||
|
||||
# Server is "not primary" or "recovering"
|
||||
if code is not None:
|
||||
if code in _NOT_PRIMARY_CODES:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
|
||||
# Other errors
|
||||
# findAndModify with upsert can raise duplicate key error
|
||||
if code in (11000, 11001, 12582):
|
||||
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
|
||||
elif code == 50:
|
||||
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
|
||||
elif code == 43:
|
||||
raise CursorNotFound(errmsg, code, response, max_wire_version)
|
||||
|
||||
raise OperationFailure(errmsg, code, response, max_wire_version)
|
||||
|
||||
|
||||
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
|
||||
# If the last batch had multiple errors only report
|
||||
# the last error to emulate continue_on_error.
|
||||
error = write_errors[-1]
|
||||
if error.get("code") == 11000:
|
||||
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
|
||||
raise WriteError(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _raise_write_concern_error(error: Any) -> NoReturn:
|
||||
if _wtimeout_error(error):
|
||||
# Make sure we raise WTimeoutError
|
||||
raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
|
||||
raise WriteConcernError(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
|
||||
"""Return the writeConcernError or None."""
|
||||
wce = result.get("writeConcernError")
|
||||
if wce:
|
||||
# The server reports errorLabels at the top level but it's more
|
||||
# convenient to attach it to the writeConcernError doc itself.
|
||||
error_labels = result.get("errorLabels")
|
||||
if error_labels:
|
||||
# Copy to avoid changing the original document.
|
||||
wce = wce.copy()
|
||||
wce["errorLabels"] = error_labels
|
||||
return wce
|
||||
|
||||
|
||||
def _check_write_command_response(result: Mapping[str, Any]) -> None:
|
||||
"""Backward compatibility helper for write command error handling."""
|
||||
# Prefer write errors over write concern errors
|
||||
write_errors = result.get("writeErrors")
|
||||
if write_errors:
|
||||
_raise_last_write_error(write_errors)
|
||||
|
||||
wce = _get_wce_doc(result)
|
||||
if wce:
|
||||
_raise_write_concern_error(wce)
|
||||
|
||||
|
||||
def _fields_list_to_dict(
|
||||
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
|
||||
) -> Mapping[str, Any]:
|
||||
"""Takes a sequence of field names and returns a matching dictionary.
|
||||
|
||||
["a", "b"] becomes {"a": 1, "b": 1}
|
||||
|
||||
and
|
||||
|
||||
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
|
||||
"""
|
||||
if isinstance(fields, abc.Mapping):
|
||||
return fields
|
||||
|
||||
if isinstance(fields, (abc.Sequence, abc.Set)):
|
||||
if not all(isinstance(field, str) for field in fields):
|
||||
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
|
||||
return dict.fromkeys(fields, 1)
|
||||
|
||||
raise TypeError(f"{option_name} must be a mapping or list of key names")
|
||||
|
||||
|
||||
def _handle_exception() -> None:
|
||||
"""Print exceptions raised by subscribers to stderr."""
|
||||
# Heavily influenced by logging.Handler.handleError.
|
||||
|
||||
# See note here:
|
||||
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
|
||||
if sys.stderr:
|
||||
einfo = sys.exc_info()
|
||||
try:
|
||||
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
del einfo
|
||||
|
||||
|
||||
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
@ -283,8 +38,8 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
def _handle_reauth(func: F) -> F:
|
||||
async def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.asynchronous.message import _BulkWriteContext
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.message import _BulkWriteContext
|
||||
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
@ -292,16 +47,16 @@ def _handle_reauth(func: F) -> F:
|
||||
if no_reauth:
|
||||
raise
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
# Look for an argument that either is a Connection
|
||||
# Look for an argument that either is a AsyncConnection
|
||||
# or has a connection attribute, so we can trigger
|
||||
# a reauth.
|
||||
conn = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Connection):
|
||||
if isinstance(arg, AsyncConnection):
|
||||
conn = arg
|
||||
break
|
||||
if isinstance(arg, _BulkWriteContext):
|
||||
conn = arg.conn
|
||||
conn = arg.conn # type: ignore[assignment]
|
||||
break
|
||||
if conn:
|
||||
await conn.authenticate(reauthenticate=True)
|
||||
|
||||
@ -1,171 +0,0 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from bson import UuidRepresentation, json_util
|
||||
from bson.json_util import JSONOptions, _truncate_documents
|
||||
from pymongo.asynchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class _CommandStatusMessage(str, enum.Enum):
|
||||
STARTED = "Command started"
|
||||
SUCCEEDED = "Command succeeded"
|
||||
FAILED = "Command failed"
|
||||
|
||||
|
||||
class _ServerSelectionStatusMessage(str, enum.Enum):
|
||||
STARTED = "Server selection started"
|
||||
SUCCEEDED = "Server selection succeeded"
|
||||
FAILED = "Server selection failed"
|
||||
WAITING = "Waiting for suitable server to become available"
|
||||
|
||||
|
||||
class _ConnectionStatusMessage(str, enum.Enum):
|
||||
POOL_CREATED = "Connection pool created"
|
||||
POOL_READY = "Connection pool ready"
|
||||
POOL_CLOSED = "Connection pool closed"
|
||||
POOL_CLEARED = "Connection pool cleared"
|
||||
|
||||
CONN_CREATED = "Connection created"
|
||||
CONN_READY = "Connection ready"
|
||||
CONN_CLOSED = "Connection closed"
|
||||
|
||||
CHECKOUT_STARTED = "Connection checkout started"
|
||||
CHECKOUT_SUCCEEDED = "Connection checked out"
|
||||
CHECKOUT_FAILED = "Connection checkout failed"
|
||||
CHECKEDIN = "Connection checked in"
|
||||
|
||||
|
||||
_DEFAULT_DOCUMENT_LENGTH = 1000
|
||||
_SENSITIVE_COMMANDS = [
|
||||
"authenticate",
|
||||
"saslStart",
|
||||
"saslContinue",
|
||||
"getnonce",
|
||||
"createUser",
|
||||
"updateUser",
|
||||
"copydbgetnonce",
|
||||
"copydbsaslstart",
|
||||
"copydb",
|
||||
]
|
||||
_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"]
|
||||
_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"]
|
||||
_DOCUMENT_NAMES = ["command", "reply", "failure"]
|
||||
_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
_COMMAND_LOGGER = logging.getLogger("pymongo.command")
|
||||
_CONNECTION_LOGGER = logging.getLogger("pymongo.connection")
|
||||
_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection")
|
||||
_CLIENT_LOGGER = logging.getLogger("pymongo.client")
|
||||
_VERBOSE_CONNECTION_ERROR_REASONS = {
|
||||
ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed",
|
||||
ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed",
|
||||
ConnectionClosedReason.STALE: "Connection pool was stale",
|
||||
ConnectionClosedReason.ERROR: "An error occurred while using the connection",
|
||||
ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection",
|
||||
ConnectionClosedReason.IDLE: "Connection was idle too long",
|
||||
ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout",
|
||||
}
|
||||
|
||||
|
||||
def _debug_log(logger: logging.Logger, **fields: Any) -> None:
|
||||
logger.debug(LogMessage(**fields))
|
||||
|
||||
|
||||
def _verbose_connection_error_reason(reason: str) -> str:
|
||||
return _VERBOSE_CONNECTION_ERROR_REASONS.get(reason, reason)
|
||||
|
||||
|
||||
def _info_log(logger: logging.Logger, **fields: Any) -> None:
|
||||
logger.info(LogMessage(**fields))
|
||||
|
||||
|
||||
def _log_or_warn(logger: logging.Logger, message: str) -> None:
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
logger.info(message)
|
||||
else:
|
||||
# stacklevel=4 ensures that the warning is for the user's code.
|
||||
warnings.warn(message, UserWarning, stacklevel=4)
|
||||
|
||||
|
||||
class LogMessage:
|
||||
__slots__ = ("_kwargs", "_redacted")
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
self._kwargs = kwargs
|
||||
self._redacted = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
self._redact()
|
||||
return "%s" % (
|
||||
json_util.dumps(
|
||||
self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__()
|
||||
)
|
||||
)
|
||||
|
||||
def _is_sensitive(self, doc_name: str) -> bool:
|
||||
is_speculative_authenticate = (
|
||||
self._kwargs.pop("speculative_authenticate", False)
|
||||
or "speculativeAuthenticate" in self._kwargs[doc_name]
|
||||
)
|
||||
is_sensitive_command = (
|
||||
"commandName" in self._kwargs and self._kwargs["commandName"] in _SENSITIVE_COMMANDS
|
||||
)
|
||||
|
||||
is_sensitive_hello = (
|
||||
self._kwargs["commandName"] in _HELLO_COMMANDS and is_speculative_authenticate
|
||||
)
|
||||
|
||||
return is_sensitive_command or is_sensitive_hello
|
||||
|
||||
def _redact(self) -> None:
|
||||
if self._redacted:
|
||||
return
|
||||
self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None}
|
||||
if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"):
|
||||
self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000
|
||||
if "serviceId" in self._kwargs:
|
||||
self._kwargs["serviceId"] = str(self._kwargs["serviceId"])
|
||||
document_length = int(os.getenv("MONGOB_LOG_MAX_DOCUMENT_LENGTH", _DEFAULT_DOCUMENT_LENGTH))
|
||||
if document_length < 0:
|
||||
document_length = _DEFAULT_DOCUMENT_LENGTH
|
||||
is_server_side_error = self._kwargs.pop("isServerSideError", False)
|
||||
|
||||
for doc_name in _DOCUMENT_NAMES:
|
||||
doc = self._kwargs.get(doc_name)
|
||||
if doc:
|
||||
if doc_name == "failure" and is_server_side_error:
|
||||
doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS}
|
||||
if doc_name != "failure" and self._is_sensitive(doc_name):
|
||||
doc = json_util.dumps({})
|
||||
else:
|
||||
truncated_doc = _truncate_documents(doc, document_length)[0]
|
||||
doc = json_util.dumps(
|
||||
truncated_doc,
|
||||
json_options=_JSON_OPTIONS,
|
||||
default=lambda o: o.__repr__(),
|
||||
)
|
||||
if len(doc) > document_length:
|
||||
doc = (
|
||||
doc.encode()[:document_length].decode("unicode-escape", "ignore")
|
||||
) + "..."
|
||||
self._kwargs[doc_name] = doc
|
||||
self._redacted = True
|
||||
@ -1,125 +0,0 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Criteria to select ServerDescriptions based on maxStalenessSeconds.
|
||||
|
||||
The Max Staleness Spec says: When there is a known primary P,
|
||||
a secondary S's staleness is estimated with this formula:
|
||||
|
||||
(S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate)
|
||||
+ heartbeatFrequencyMS
|
||||
|
||||
When there is no known primary, a secondary S's staleness is estimated with:
|
||||
|
||||
SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS
|
||||
|
||||
where "SMax" is the secondary with the greatest lastWriteDate.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
|
||||
# 10 seconds to refresh secondaries' lastWriteDate values.
|
||||
IDLE_WRITE_PERIOD = 10
|
||||
SMALLEST_MAX_STALENESS = 90
|
||||
|
||||
|
||||
def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None:
|
||||
# We checked for max staleness -1 before this, it must be positive here.
|
||||
if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD:
|
||||
raise ConfigurationError(
|
||||
"maxStalenessSeconds must be at least heartbeatFrequencyMS +"
|
||||
" %d seconds. maxStalenessSeconds is set to %d,"
|
||||
" heartbeatFrequencyMS is set to %d."
|
||||
% (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000)
|
||||
)
|
||||
|
||||
if max_staleness < SMALLEST_MAX_STALENESS:
|
||||
raise ConfigurationError(
|
||||
"maxStalenessSeconds must be at least %d. "
|
||||
"maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness)
|
||||
)
|
||||
|
||||
|
||||
def _with_primary(max_staleness: int, selection: Selection) -> Selection:
|
||||
"""Apply max_staleness, in seconds, to a Selection with a known primary."""
|
||||
primary = selection.primary
|
||||
assert primary
|
||||
sds = []
|
||||
|
||||
for s in selection.server_descriptions:
|
||||
if s.server_type == SERVER_TYPE.RSSecondary:
|
||||
# See max-staleness.rst for explanation of this formula.
|
||||
assert s.last_write_date and primary.last_write_date # noqa: PT018
|
||||
staleness = (
|
||||
(s.last_update_time - s.last_write_date)
|
||||
- (primary.last_update_time - primary.last_write_date)
|
||||
+ selection.heartbeat_frequency
|
||||
)
|
||||
|
||||
if staleness <= max_staleness:
|
||||
sds.append(s)
|
||||
else:
|
||||
sds.append(s)
|
||||
|
||||
return selection.with_server_descriptions(sds)
|
||||
|
||||
|
||||
def _no_primary(max_staleness: int, selection: Selection) -> Selection:
|
||||
"""Apply max_staleness, in seconds, to a Selection with no known primary."""
|
||||
# Secondary that's replicated the most recent writes.
|
||||
smax = selection.secondary_with_max_last_write_date()
|
||||
if not smax:
|
||||
# No secondaries and no primary, short-circuit out of here.
|
||||
return selection.with_server_descriptions([])
|
||||
|
||||
sds = []
|
||||
|
||||
for s in selection.server_descriptions:
|
||||
if s.server_type == SERVER_TYPE.RSSecondary:
|
||||
# See max-staleness.rst for explanation of this formula.
|
||||
assert smax.last_write_date and s.last_write_date # noqa: PT018
|
||||
staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency
|
||||
|
||||
if staleness <= max_staleness:
|
||||
sds.append(s)
|
||||
else:
|
||||
sds.append(s)
|
||||
|
||||
return selection.with_server_descriptions(sds)
|
||||
|
||||
|
||||
def select(max_staleness: int, selection: Selection) -> Selection:
|
||||
"""Apply max_staleness, in seconds, to a Selection."""
|
||||
if max_staleness == -1:
|
||||
return selection
|
||||
|
||||
# Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or
|
||||
# ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness <
|
||||
# heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90.
|
||||
_validate_max_staleness(max_staleness, selection.heartbeat_frequency)
|
||||
|
||||
if selection.primary:
|
||||
return _with_primary(max_staleness, selection)
|
||||
else:
|
||||
return _no_primary(max_staleness, selection)
|
||||
File diff suppressed because it is too large
Load Diff
@ -58,42 +58,14 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, helpers_constants
|
||||
from pymongo.asynchronous import (
|
||||
client_session,
|
||||
common,
|
||||
database,
|
||||
helpers,
|
||||
message,
|
||||
periodic_executor,
|
||||
uri_parser,
|
||||
)
|
||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
||||
from pymongo.asynchronous import client_session, database, periodic_executor
|
||||
from pymongo.asynchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.asynchronous.client_options import ClientOptions
|
||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.logger import _CLIENT_LOGGER, _log_or_warn
|
||||
from pymongo.asynchronous.monitoring import ConnectionClosedReason
|
||||
from pymongo.asynchronous.operations import _Op
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.asynchronous.server_selectors import writable_server_selector
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
from pymongo.asynchronous.typings import (
|
||||
ClusterTime,
|
||||
_Address,
|
||||
_CollationIn,
|
||||
_DocumentType,
|
||||
_DocumentTypeArg,
|
||||
_Pipeline,
|
||||
)
|
||||
from pymongo.asynchronous.uri_parser import (
|
||||
_check_options,
|
||||
_handle_option_deprecations,
|
||||
_handle_security_options,
|
||||
_normalize_options,
|
||||
)
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
BulkWriteError,
|
||||
@ -108,29 +80,52 @@ from pymongo.errors import (
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
|
||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||
from pymongo.monitoring import ConnectionClosedReason
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
from pymongo.typings import (
|
||||
ClusterTime,
|
||||
_Address,
|
||||
_CollationIn,
|
||||
_DocumentType,
|
||||
_DocumentTypeArg,
|
||||
_Pipeline,
|
||||
)
|
||||
from pymongo.uri_parser import (
|
||||
_check_options,
|
||||
_handle_option_deprecations,
|
||||
_handle_security_options,
|
||||
_normalize_options,
|
||||
)
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.bulk import _Bulk
|
||||
from pymongo.asynchronous.client_session import ClientSession, _ServerSession
|
||||
from pymongo.asynchronous.bulk import _AsyncBulk
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.asynchronous.message import _CursorAddress, _GetMore, _Query
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.response import Response
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.response import Response
|
||||
from pymongo.server_selectors import Selection
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], Coroutine[Any, Any, T]]
|
||||
_WriteCall = Callable[
|
||||
[Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T]
|
||||
]
|
||||
_ReadCall = Callable[
|
||||
[Optional["ClientSession"], "Server", "Connection", _ServerMode], Coroutine[Any, Any, T]
|
||||
[Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode],
|
||||
Coroutine[Any, Any, T],
|
||||
]
|
||||
|
||||
_IS_SYNC = False
|
||||
@ -892,7 +887,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._kill_cursors_executor = executor
|
||||
self._opened = False
|
||||
|
||||
def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]:
|
||||
def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[bool]:
|
||||
return self._options.load_balanced and not (session and session.in_transaction)
|
||||
|
||||
def _after_fork(self) -> None:
|
||||
@ -915,7 +910,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
batch_size: Optional[int] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
start_at_operation_time: Optional[Timestamp] = None,
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
session: Optional[client_session.AsyncClientSession] = None,
|
||||
start_after: Optional[Mapping[str, Any]] = None,
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
@ -989,7 +984,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
the specified :class:`~bson.timestamp.Timestamp`. Requires
|
||||
MongoDB >= 4.0.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param start_after: The same as `resume_after` except that
|
||||
`start_after` can resume notifications after an invalidate event.
|
||||
This option and `resume_after` are mutually exclusive.
|
||||
@ -1163,30 +1158,30 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Request that a cursor and/or connection be cleaned up soon."""
|
||||
self._kill_cursors_queue.append((address, cursor_id, conn_mgr))
|
||||
|
||||
def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
|
||||
def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
|
||||
server_session = _EmptyServerSession()
|
||||
opts = client_session.SessionOptions(**kwargs)
|
||||
return client_session.ClientSession(self, server_session, opts, implicit)
|
||||
return client_session.AsyncClientSession(self, server_session, opts, implicit)
|
||||
|
||||
def start_session(
|
||||
self,
|
||||
causal_consistency: Optional[bool] = None,
|
||||
default_transaction_options: Optional[client_session.TransactionOptions] = None,
|
||||
snapshot: Optional[bool] = False,
|
||||
) -> client_session.ClientSession:
|
||||
) -> client_session.AsyncClientSession:
|
||||
"""Start a logical session.
|
||||
|
||||
This method takes the same parameters as
|
||||
:class:`~pymongo.client_session.SessionOptions`. See the
|
||||
:mod:`~pymongo.client_session` module for details and examples.
|
||||
|
||||
A :class:`~pymongo.client_session.ClientSession` may only be used with
|
||||
the MongoClient that started it. :class:`ClientSession` instances are
|
||||
A :class:`~pymongo.client_session.AsyncClientSession` may only be used with
|
||||
the MongoClient that started it. :class:`AsyncClientSession` instances are
|
||||
**not thread-safe or fork-safe**. They can only be used by one thread
|
||||
or process at a time. A single :class:`ClientSession` cannot be used
|
||||
or process at a time. A single :class:`AsyncClientSession` cannot be used
|
||||
to run multiple operations concurrently.
|
||||
|
||||
:return: An instance of :class:`~pymongo.client_session.ClientSession`.
|
||||
:return: An instance of :class:`~pymongo.client_session.AsyncClientSession`.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
@ -1197,7 +1192,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
|
||||
def _ensure_session(
|
||||
self, session: Optional[AsyncClientSession] = None
|
||||
) -> Optional[AsyncClientSession]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session:
|
||||
return session
|
||||
@ -1211,7 +1208,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return None
|
||||
|
||||
def _send_cluster_time(
|
||||
self, command: MutableMapping[str, Any], session: Optional[ClientSession]
|
||||
self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession]
|
||||
) -> None:
|
||||
topology_time = self._topology.max_cluster_time()
|
||||
session_time = session.cluster_time if session else None
|
||||
@ -1474,7 +1471,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _end_sessions(self, session_ids: list[_ServerSession]) -> None:
|
||||
"""Send endSessions command(s) with the given session ids."""
|
||||
try:
|
||||
# Use Connection.command directly to avoid implicitly creating
|
||||
# Use AsyncConnection.command directly to avoid implicitly creating
|
||||
# another session.
|
||||
async with await self._conn_for_reads(
|
||||
ReadPreference.PRIMARY_PREFERRED, None, operation=_Op.END_SESSIONS
|
||||
@ -1535,8 +1532,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _checkout(
|
||||
self, server: Server, session: Optional[ClientSession]
|
||||
) -> AsyncGenerator[Connection, None]:
|
||||
self, server: Server, session: Optional[AsyncClientSession]
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
in_txn = session and session.in_transaction
|
||||
async with _MongoClientErrorHandler(self, server, session) as err_handler:
|
||||
# Reuse the pinned connection, if it exists.
|
||||
@ -1570,7 +1567,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _select_server(
|
||||
self,
|
||||
server_selector: Callable[[Selection], Selection],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
address: Optional[_Address] = None,
|
||||
deprioritized_servers: Optional[list[Server]] = None,
|
||||
@ -1581,7 +1578,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
:Parameters:
|
||||
- `server_selector`: The server selector to use if the session is
|
||||
not pinned and no address is given.
|
||||
- `session`: The ClientSession for the next operation, or None. May
|
||||
- `session`: The AsyncClientSession for the next operation, or None. May
|
||||
be pinned to a mongos server address.
|
||||
- `address` (optional): Address when sending a message
|
||||
to a specific server, used for getMore.
|
||||
@ -1615,15 +1612,15 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
raise
|
||||
|
||||
async def _conn_for_writes(
|
||||
self, session: Optional[ClientSession], operation: str
|
||||
) -> AsyncContextManager[Connection]:
|
||||
self, session: Optional[AsyncClientSession], operation: str
|
||||
) -> AsyncContextManager[AsyncConnection]:
|
||||
server = await self._select_server(writable_server_selector, session, operation)
|
||||
return self._checkout(server, session)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _conn_from_server(
|
||||
self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession]
|
||||
) -> AsyncGenerator[tuple[Connection, _ServerMode], None]:
|
||||
self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession]
|
||||
) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]:
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
# Get a connection for a server matching the read preference, and yield
|
||||
# conn with the effective read preference. The Server Selection
|
||||
@ -1648,9 +1645,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _conn_for_reads(
|
||||
self,
|
||||
read_preference: _ServerMode,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> AsyncContextManager[tuple[Connection, _ServerMode]]:
|
||||
) -> AsyncContextManager[tuple[AsyncConnection, _ServerMode]]:
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
server = await self._select_server(read_preference, session, operation)
|
||||
return self._conn_from_server(read_preference, server, session)
|
||||
@ -1672,13 +1669,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if operation.conn_mgr:
|
||||
server = await self._select_server(
|
||||
operation.read_preference,
|
||||
operation.session,
|
||||
operation.session, # type: ignore[arg-type]
|
||||
operation.name,
|
||||
address=address,
|
||||
)
|
||||
|
||||
async with operation.conn_mgr._alock:
|
||||
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler:
|
||||
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||
return await server.run_operation(
|
||||
operation.conn_mgr.conn,
|
||||
@ -1690,9 +1687,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
async def _cmd(
|
||||
_session: Optional[ClientSession],
|
||||
_session: Optional[AsyncClientSession],
|
||||
server: Server,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
read_preference: _ServerMode,
|
||||
) -> Response:
|
||||
operation.reset() # Reset op in case of retry.
|
||||
@ -1708,9 +1705,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return await self._retryable_read(
|
||||
_cmd,
|
||||
operation.read_preference,
|
||||
operation.session,
|
||||
operation.session, # type: ignore[arg-type]
|
||||
address=address,
|
||||
retryable=isinstance(operation, message._Query),
|
||||
retryable=isinstance(operation, _Query),
|
||||
operation=operation.name,
|
||||
)
|
||||
|
||||
@ -1718,8 +1715,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
retryable: bool,
|
||||
func: _WriteCall[T],
|
||||
session: Optional[ClientSession],
|
||||
bulk: Optional[_Bulk],
|
||||
session: Optional[AsyncClientSession],
|
||||
bulk: Optional[_AsyncBulk],
|
||||
operation: str,
|
||||
operation_id: Optional[int] = None,
|
||||
) -> T:
|
||||
@ -1748,8 +1745,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _retry_internal(
|
||||
self,
|
||||
func: _WriteCall[T] | _ReadCall[T],
|
||||
session: Optional[ClientSession],
|
||||
bulk: Optional[_Bulk],
|
||||
session: Optional[AsyncClientSession],
|
||||
bulk: Optional[_AsyncBulk],
|
||||
operation: str,
|
||||
is_read: bool = False,
|
||||
address: Optional[_Address] = None,
|
||||
@ -1787,7 +1784,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
func: _ReadCall[T],
|
||||
read_pref: _ServerMode,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
address: Optional[_Address] = None,
|
||||
retryable: bool = True,
|
||||
@ -1830,9 +1827,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
retryable: bool,
|
||||
func: _WriteCall[T],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
bulk: Optional[_Bulk] = None,
|
||||
bulk: Optional[_AsyncBulk] = None,
|
||||
operation_id: Optional[int] = None,
|
||||
) -> T:
|
||||
"""Execute an operation with consecutive retries if possible
|
||||
@ -1856,7 +1853,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
cursor_id: int,
|
||||
address: Optional[_CursorAddress],
|
||||
conn_mgr: _ConnectionManager,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
explicit_session: bool,
|
||||
) -> None:
|
||||
"""Cleanup a cursor from __del__ without locking.
|
||||
@ -1880,7 +1877,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
cursor_id: int,
|
||||
address: Optional[_CursorAddress],
|
||||
conn_mgr: _ConnectionManager,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
explicit_session: bool,
|
||||
) -> None:
|
||||
"""Cleanup a cursor from cursor.close() using a lock.
|
||||
@ -1913,7 +1910,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
cursor_id: int,
|
||||
address: Optional[_CursorAddress],
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
conn_mgr: Optional[_ConnectionManager] = None,
|
||||
) -> None:
|
||||
"""Send a kill cursors message with the given id.
|
||||
@ -1941,7 +1938,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
cursor_ids: Sequence[int],
|
||||
address: Optional[_CursorAddress],
|
||||
topology: Topology,
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
) -> None:
|
||||
"""Send a kill cursors message with the given ids."""
|
||||
if address:
|
||||
@ -1960,8 +1957,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
cursor_ids: Sequence[int],
|
||||
address: _CursorAddress,
|
||||
session: Optional[ClientSession],
|
||||
conn: Connection,
|
||||
session: Optional[AsyncClientSession],
|
||||
conn: AsyncConnection,
|
||||
) -> None:
|
||||
namespace = address.namespace
|
||||
db, coll = namespace.split(".", 1)
|
||||
@ -1994,7 +1991,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# can be caught in _process_periodic_tasks
|
||||
raise
|
||||
else:
|
||||
helpers._handle_exception()
|
||||
helpers_shared._handle_exception()
|
||||
|
||||
# Don't re-open topology if it's closed and there's no pending cursors.
|
||||
if address_to_cursor_ids:
|
||||
@ -2006,7 +2003,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
else:
|
||||
helpers._handle_exception()
|
||||
helpers_shared._handle_exception()
|
||||
|
||||
# This method is run periodically by a background thread.
|
||||
async def _process_periodic_tasks(self) -> None:
|
||||
@ -2020,7 +2017,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
else:
|
||||
helpers._handle_exception()
|
||||
helpers_shared._handle_exception()
|
||||
|
||||
def _return_server_session(
|
||||
self, server_session: Union[_ServerSession, _EmptyServerSession]
|
||||
@ -2032,12 +2029,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _tmp_session(
|
||||
self, session: Optional[client_session.ClientSession], close: bool = True
|
||||
) -> AsyncGenerator[Optional[client_session.ClientSession], None, None]:
|
||||
self, session: Optional[client_session.AsyncClientSession], close: bool = True
|
||||
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.ClientSession):
|
||||
raise ValueError("'session' argument must be a ClientSession or None.")
|
||||
if not isinstance(session, client_session.AsyncClientSession):
|
||||
raise ValueError("'session' argument must be an AsyncClientSession or None.")
|
||||
# Don't call end_session.
|
||||
yield session
|
||||
return
|
||||
@ -2061,19 +2058,19 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
yield None
|
||||
|
||||
async def _process_response(
|
||||
self, reply: Mapping[str, Any], session: Optional[ClientSession]
|
||||
self, reply: Mapping[str, Any], session: Optional[AsyncClientSession]
|
||||
) -> None:
|
||||
await self._topology.receive_cluster_time(reply.get("$clusterTime"))
|
||||
if session is not None:
|
||||
session._process_response(reply)
|
||||
|
||||
async def server_info(
|
||||
self, session: Optional[client_session.ClientSession] = None
|
||||
self, session: Optional[client_session.AsyncClientSession] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Get information about the MongoDB server we're connected to.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
@ -2087,7 +2084,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def _list_databases(
|
||||
self,
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
session: Optional[client_session.AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCommandCursor[dict[str, Any]]:
|
||||
@ -2109,14 +2106,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def list_databases(
|
||||
self,
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
session: Optional[client_session.AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCommandCursor[dict[str, Any]]:
|
||||
"""Get a cursor over the databases of the connected server.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
:param kwargs: Optional parameters of the
|
||||
@ -2134,13 +2131,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
async def list_database_names(
|
||||
self,
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
session: Optional[client_session.AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> list[str]:
|
||||
"""Get a list of the names of all databases on the connected server.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -2156,7 +2153,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
async def drop_database(
|
||||
self,
|
||||
name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]],
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
session: Optional[client_session.AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Drop a database.
|
||||
@ -2168,7 +2165,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
:class:`~pymongo.database.Database` instance representing the
|
||||
database to drop
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
:class:`~pymongo.client_session.AsyncClientSession`.
|
||||
:param comment: A user-provided comment to attach to this
|
||||
command.
|
||||
|
||||
@ -2236,10 +2233,10 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong
|
||||
# Do not consult writeConcernError for pre-4.4 mongos.
|
||||
if isinstance(exc, WriteConcernError) and is_mongos:
|
||||
pass
|
||||
elif code in helpers_constants._RETRYABLE_ERROR_CODES:
|
||||
elif code in helpers_shared._RETRYABLE_ERROR_CODES:
|
||||
exc._add_error_label("RetryableWriteError")
|
||||
|
||||
# Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is
|
||||
# AsyncConnection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is
|
||||
# handled above.
|
||||
if isinstance(exc, ConnectionFailure) and not isinstance(
|
||||
exc, (NotPrimaryError, WaitQueueTimeoutError)
|
||||
@ -2261,7 +2258,9 @@ class _MongoClientErrorHandler:
|
||||
"handled",
|
||||
)
|
||||
|
||||
def __init__(self, client: AsyncMongoClient, server: Server, session: Optional[ClientSession]):
|
||||
def __init__(
|
||||
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
|
||||
):
|
||||
self.client = client
|
||||
self.server_address = server.description.address
|
||||
self.session = session
|
||||
@ -2275,7 +2274,7 @@ class _MongoClientErrorHandler:
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
self.handled = False
|
||||
|
||||
def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None:
|
||||
def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None:
|
||||
"""Provide socket information to the error handler."""
|
||||
self.max_wire_version = conn.max_wire_version
|
||||
self.sock_generation = conn.generation
|
||||
@ -2327,10 +2326,10 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
self,
|
||||
mongo_client: AsyncMongoClient,
|
||||
func: _WriteCall[T] | _ReadCall[T],
|
||||
bulk: Optional[_Bulk],
|
||||
bulk: Optional[_AsyncBulk],
|
||||
operation: str,
|
||||
is_read: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
read_pref: Optional[_ServerMode] = None,
|
||||
address: Optional[_Address] = None,
|
||||
retryable: bool = False,
|
||||
@ -2392,7 +2391,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
exc_code = getattr(exc, "code", None)
|
||||
if self._is_not_eligible_for_retry() or (
|
||||
isinstance(exc, OperationFailure)
|
||||
and exc_code not in helpers_constants._RETRYABLE_ERROR_CODES
|
||||
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
|
||||
):
|
||||
raise
|
||||
self._retrying = True
|
||||
|
||||
@ -21,19 +21,20 @@ import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo import common
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.asynchronous import common, periodic_executor
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo.asynchronous.periodic_executor import _shutdown_executors
|
||||
from pymongo.asynchronous.pool import _is_faas
|
||||
from pymongo.asynchronous.read_preferences import MovingAverage
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.srv_resolver import _SrvResolver
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import Connection, Pool, _CancellationContext
|
||||
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology
|
||||
|
||||
@ -294,7 +295,7 @@ class Monitor(MonitorBase):
|
||||
)
|
||||
return sd
|
||||
|
||||
async def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
|
||||
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]:
|
||||
"""Return (Hello, round_trip_time).
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -33,20 +33,18 @@ from typing import (
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import helpers as _async_helpers
|
||||
from pymongo.asynchronous import message as _async_message
|
||||
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.asynchronous.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.asynchronous.monitoring import _is_speculative_authenticate
|
||||
from pymongo import _csot, helpers_shared, message
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
_OperationCancelled,
|
||||
)
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.network_layer import (
|
||||
_POLL_TIMEOUT,
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
@ -58,27 +56,27 @@ from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.monitoring import _EventListeners
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
async def command(
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
is_mongos: bool,
|
||||
read_preference: Optional[_ServerMode],
|
||||
codec_options: CodecOptions[_DocumentType],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
client: Optional[AsyncMongoClient],
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
@ -97,13 +95,13 @@ async def command(
|
||||
) -> _DocumentType:
|
||||
"""Execute a command over the socket, or raise socket.error.
|
||||
|
||||
:param conn: a Connection instance
|
||||
:param conn: a AsyncConnection instance
|
||||
:param dbname: name of the database on which to run the command
|
||||
:param spec: a command document as an ordered dict type, eg SON.
|
||||
:param is_mongos: are we connected to a mongos?
|
||||
:param read_preference: a read preference
|
||||
:param codec_options: a CodecOptions instance
|
||||
:param session: optional ClientSession instance.
|
||||
:param session: optional AsyncClientSession instance.
|
||||
:param client: optional AsyncMongoClient instance for updating $clusterTime.
|
||||
:param check: raise OperationFailure if there are errors
|
||||
:param allowable_errors: errors to ignore if `check` is True
|
||||
@ -130,7 +128,7 @@ async def command(
|
||||
orig = spec
|
||||
if is_mongos and not use_op_msg:
|
||||
assert read_preference is not None
|
||||
spec = _async_message._maybe_add_read_preference(spec, read_preference)
|
||||
spec = message._maybe_add_read_preference(spec, read_preference)
|
||||
if read_concern and not (session and session.in_transaction):
|
||||
if read_concern.level:
|
||||
spec["readConcern"] = read_concern.document
|
||||
@ -158,22 +156,20 @@ async def command(
|
||||
if use_op_msg:
|
||||
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
|
||||
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
|
||||
request_id, msg, size, max_doc_size = _async_message._op_msg(
|
||||
request_id, msg, size, max_doc_size = message._op_msg(
|
||||
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
|
||||
)
|
||||
# If this is an unacknowledged write then make sure the encoded doc(s)
|
||||
# are small enough, otherwise rely on the server to return an error.
|
||||
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
|
||||
_async_message._raise_document_too_large(name, size, max_bson_size)
|
||||
message._raise_document_too_large(name, size, max_bson_size)
|
||||
else:
|
||||
request_id, msg, size = _async_message._query(
|
||||
request_id, msg, size = message._query(
|
||||
0, ns, 0, -1, spec, None, codec_options, compression_ctx
|
||||
)
|
||||
|
||||
if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD:
|
||||
_async_message._raise_document_too_large(
|
||||
name, size, max_bson_size + _async_message._COMMAND_OVERHEAD
|
||||
)
|
||||
if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
|
||||
message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
|
||||
if client is not None:
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
@ -220,7 +216,7 @@ async def command(
|
||||
if client:
|
||||
await client._process_response(response_doc, session)
|
||||
if check:
|
||||
_async_helpers._check_command_response(
|
||||
helpers_shared._check_command_response(
|
||||
response_doc,
|
||||
conn.max_wire_version,
|
||||
allowable_errors,
|
||||
@ -231,7 +227,7 @@ async def command(
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _async_message._convert_exception(exc)
|
||||
failure = message._convert_exception(exc)
|
||||
if client is not None:
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
@ -310,7 +306,7 @@ async def command(
|
||||
|
||||
|
||||
async def receive_message(
|
||||
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
@ -355,7 +351,7 @@ async def receive_message(
|
||||
return unpack_reply(data)
|
||||
|
||||
|
||||
async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn
|
||||
timed_out = False
|
||||
@ -390,7 +386,7 @@ async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
|
||||
|
||||
async def _receive_data_on_socket(
|
||||
conn: Connection, length: int, deadline: Optional[float]
|
||||
conn: AsyncConnection, length: int, deadline: Optional[float]
|
||||
) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
|
||||
@ -1,625 +0,0 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Operation class definitions."""
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo.asynchronous import helpers
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.common import validate_is_mapping, validate_list
|
||||
from pymongo.asynchronous.helpers import _gen_index_name, _index_document, _index_list
|
||||
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.bulk import _Bulk
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
|
||||
_IndexList = Union[
|
||||
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
|
||||
]
|
||||
_IndexKeyHint = Union[str, _IndexList]
|
||||
|
||||
|
||||
class _Op(str, enum.Enum):
|
||||
ABORT = "abortTransaction"
|
||||
AGGREGATE = "aggregate"
|
||||
COMMIT = "commitTransaction"
|
||||
COUNT = "count"
|
||||
CREATE = "create"
|
||||
CREATE_INDEXES = "createIndexes"
|
||||
CREATE_SEARCH_INDEXES = "createSearchIndexes"
|
||||
DELETE = "delete"
|
||||
DISTINCT = "distinct"
|
||||
DROP = "drop"
|
||||
DROP_DATABASE = "dropDatabase"
|
||||
DROP_INDEXES = "dropIndexes"
|
||||
DROP_SEARCH_INDEXES = "dropSearchIndexes"
|
||||
END_SESSIONS = "endSessions"
|
||||
FIND_AND_MODIFY = "findAndModify"
|
||||
FIND = "find"
|
||||
INSERT = "insert"
|
||||
LIST_COLLECTIONS = "listCollections"
|
||||
LIST_INDEXES = "listIndexes"
|
||||
LIST_SEARCH_INDEX = "listSearchIndexes"
|
||||
LIST_DATABASES = "listDatabases"
|
||||
UPDATE = "update"
|
||||
UPDATE_INDEX = "updateIndex"
|
||||
UPDATE_SEARCH_INDEX = "updateSearchIndex"
|
||||
RENAME = "rename"
|
||||
GETMORE = "getMore"
|
||||
KILL_CURSORS = "killCursors"
|
||||
TEST = "testOperation"
|
||||
|
||||
|
||||
class InsertOne(Generic[_DocumentType]):
|
||||
"""Represents an insert_one operation."""
|
||||
|
||||
__slots__ = ("_doc",)
|
||||
|
||||
def __init__(self, document: _DocumentType) -> None:
|
||||
"""Create an InsertOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
|
||||
|
||||
:param document: The document to insert. If the document is missing an
|
||||
_id field one will be added.
|
||||
"""
|
||||
self._doc = document
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
|
||||
"""Add this operation to the _Bulk instance `bulkobj`."""
|
||||
bulkobj.add_insert(self._doc) # type: ignore[arg-type]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"InsertOne({self._doc!r})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return other._doc == self._doc
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class DeleteOne:
|
||||
"""Represents a delete_one operation."""
|
||||
|
||||
__slots__ = ("_filter", "_collation", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create a DeleteOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the document to delete.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hint`` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
self._filter = filter
|
||||
self._collation = collation
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
|
||||
"""Add this operation to the _Bulk instance `bulkobj`."""
|
||||
bulkobj.add_delete(
|
||||
self._filter,
|
||||
1,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (other._filter, other._collation, other._hint) == (
|
||||
self._filter,
|
||||
self._collation,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class DeleteMany:
|
||||
"""Represents a delete_many operation."""
|
||||
|
||||
__slots__ = ("_filter", "_collation", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create a DeleteMany instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the documents to delete.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hint`` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
self._filter = filter
|
||||
self._collation = collation
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
|
||||
"""Add this operation to the _Bulk instance `bulkobj`."""
|
||||
bulkobj.add_delete(
|
||||
self._filter,
|
||||
0,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (other._filter, other._collation, other._hint) == (
|
||||
self._filter,
|
||||
self._collation,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class ReplaceOne(Generic[_DocumentType]):
|
||||
"""Represents a replace_one operation."""
|
||||
|
||||
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
replacement: Union[_DocumentType, RawBSONDocument],
|
||||
upsert: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create a ReplaceOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the document to replace.
|
||||
:param replacement: The new document.
|
||||
:param upsert: If ``True``, perform an insert if no documents
|
||||
match the filter.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hint`` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the ``collation`` option.
|
||||
"""
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
self._filter = filter
|
||||
self._doc = replacement
|
||||
self._upsert = upsert
|
||||
self._collation = collation
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
|
||||
"""Add this operation to the _Bulk instance `bulkobj`."""
|
||||
bulkobj.add_replace(
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (
|
||||
other._filter,
|
||||
other._doc,
|
||||
other._upsert,
|
||||
other._collation,
|
||||
other._hint,
|
||||
) == (
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
other._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
|
||||
self.__class__.__name__,
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._hint,
|
||||
)
|
||||
|
||||
|
||||
class _UpdateOp:
|
||||
"""Private base class for update operations."""
|
||||
|
||||
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
doc: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool,
|
||||
collation: Optional[_CollationIn],
|
||||
array_filters: Optional[list[Mapping[str, Any]]],
|
||||
hint: Optional[_IndexKeyHint],
|
||||
):
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
if array_filters is not None:
|
||||
validate_list("array_filters", array_filters)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
|
||||
self._filter = filter
|
||||
self._doc = doc
|
||||
self._upsert = upsert
|
||||
self._collation = collation
|
||||
self._array_filters = array_filters
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, type(self)):
|
||||
return (
|
||||
other._filter,
|
||||
other._doc,
|
||||
other._upsert,
|
||||
other._collation,
|
||||
other._array_filters,
|
||||
other._hint,
|
||||
) == (
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._array_filters,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
|
||||
self.__class__.__name__,
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._array_filters,
|
||||
self._hint,
|
||||
)
|
||||
|
||||
|
||||
class UpdateOne(_UpdateOp):
|
||||
"""Represents an update_one operation."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[list[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Represents an update_one operation.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the document to update.
|
||||
:param update: The modifications to apply.
|
||||
:param upsert: If ``True``, perform an insert if no documents
|
||||
match the filter.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param array_filters: A list of filters specifying which
|
||||
array elements an update should apply.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the `hint` option.
|
||||
.. versionchanged:: 3.9
|
||||
Added the ability to accept a pipeline as the `update`.
|
||||
.. versionchanged:: 3.6
|
||||
Added the `array_filters` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
super().__init__(filter, update, upsert, collation, array_filters, hint)
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
|
||||
"""Add this operation to the _Bulk instance `bulkobj`."""
|
||||
bulkobj.add_update(
|
||||
self._filter,
|
||||
self._doc,
|
||||
False,
|
||||
self._upsert,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
array_filters=self._array_filters,
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
|
||||
class UpdateMany(_UpdateOp):
|
||||
"""Represents an update_many operation."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[list[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create an UpdateMany instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the documents to update.
|
||||
:param update: The modifications to apply.
|
||||
:param upsert: If ``True``, perform an insert if no documents
|
||||
match the filter.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param array_filters: A list of filters specifying which
|
||||
array elements an update should apply.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the `hint` option.
|
||||
.. versionchanged:: 3.9
|
||||
Added the ability to accept a pipeline as the `update`.
|
||||
.. versionchanged:: 3.6
|
||||
Added the `array_filters` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
super().__init__(filter, update, upsert, collation, array_filters, hint)
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
|
||||
"""Add this operation to the _Bulk instance `bulkobj`."""
|
||||
bulkobj.add_update(
|
||||
self._filter,
|
||||
self._doc,
|
||||
True,
|
||||
self._upsert,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
array_filters=self._array_filters,
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
|
||||
class IndexModel:
|
||||
"""Represents an index to create."""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
|
||||
"""Create an Index instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.create_indexes`.
|
||||
|
||||
Takes either a single key or a list containing (key, direction) pairs
|
||||
or keys. If no direction is given, :data:`~pymongo.ASCENDING` will
|
||||
be assumed.
|
||||
The key(s) must be an instance of :class:`str`, and the direction(s) must
|
||||
be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
|
||||
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
|
||||
:data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
|
||||
|
||||
Valid options include, but are not limited to:
|
||||
|
||||
- `name`: custom name to use for this index - if none is
|
||||
given, a name will be generated.
|
||||
- `unique`: if ``True``, creates a uniqueness constraint on the index.
|
||||
- `background`: if ``True``, this index should be created in the
|
||||
background.
|
||||
- `sparse`: if ``True``, omit from the index any documents that lack
|
||||
the indexed field.
|
||||
- `bucketSize`: for use with geoHaystack indexes.
|
||||
Number of documents to group together within a certain proximity
|
||||
to a given longitude and latitude.
|
||||
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
|
||||
index.
|
||||
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
|
||||
index.
|
||||
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
|
||||
collection. MongoDB will automatically delete documents from
|
||||
this collection after <int> seconds. The indexed field must
|
||||
be a UTC datetime or the data will not expire.
|
||||
- `partialFilterExpression`: A document that specifies a filter for
|
||||
a partial index.
|
||||
- `collation`: An instance of :class:`~pymongo.collation.Collation`
|
||||
that specifies the collation to use.
|
||||
- `wildcardProjection`: Allows users to include or exclude specific
|
||||
field paths from a `wildcard index`_ using the { "$**" : 1} key
|
||||
pattern. Requires MongoDB >= 4.2.
|
||||
- `hidden`: if ``True``, this index will be hidden from the query
|
||||
planner and will not be evaluated as part of query plan
|
||||
selection. Requires MongoDB >= 4.4.
|
||||
|
||||
See the MongoDB documentation for a full list of supported options by
|
||||
server version.
|
||||
|
||||
:param keys: a single key or a list containing (key, direction) pairs
|
||||
or keys specifying the index to create.
|
||||
:param kwargs: any additional index creation
|
||||
options (see the above list) should be passed as keyword
|
||||
arguments.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hidden`` option.
|
||||
.. versionchanged:: 3.2
|
||||
Added the ``partialFilterExpression`` option to support partial
|
||||
indexes.
|
||||
|
||||
.. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
|
||||
"""
|
||||
keys = _index_list(keys)
|
||||
if kwargs.get("name") is None:
|
||||
kwargs["name"] = _gen_index_name(keys)
|
||||
kwargs["key"] = _index_document(keys)
|
||||
collation = validate_collation_or_none(kwargs.pop("collation", None))
|
||||
self.__document = kwargs
|
||||
if collation is not None:
|
||||
self.__document["collation"] = collation
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""An index document suitable for passing to the createIndexes
|
||||
command.
|
||||
"""
|
||||
return self.__document
|
||||
|
||||
|
||||
class SearchIndexModel:
|
||||
"""Represents a search index to create."""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
definition: Mapping[str, Any],
|
||||
name: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a Search Index instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`.
|
||||
|
||||
:param definition: The definition for this index.
|
||||
:param name: The name for this index, if present.
|
||||
:param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
|
||||
:param kwargs: Keyword arguments supplying any additional options.
|
||||
|
||||
.. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
|
||||
.. versionadded:: 4.5
|
||||
.. versionchanged:: 4.7
|
||||
Added the type and kwargs arguments.
|
||||
"""
|
||||
self.__document: dict[str, Any] = {}
|
||||
if name is not None:
|
||||
self.__document["name"] = name
|
||||
self.__document["definition"] = definition
|
||||
if type is not None:
|
||||
self.__document["type"] = type
|
||||
self.__document.update(kwargs)
|
||||
|
||||
@property
|
||||
def document(self) -> Mapping[str, Any]:
|
||||
"""The document for this index."""
|
||||
return self.__document
|
||||
@ -16,17 +16,14 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -39,39 +36,18 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import bson
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
from pymongo import __version__, _csot
|
||||
from pymongo.asynchronous import helpers
|
||||
from pymongo import _csot, helpers_shared
|
||||
from pymongo.asynchronous.client_session import _validate_session_write_concern
|
||||
from pymongo.asynchronous.common import (
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.asynchronous.network import command, receive_message
|
||||
from pymongo.common import (
|
||||
MAX_BSON_SIZE,
|
||||
MAX_CONNECTING,
|
||||
MAX_IDLE_TIME_SEC,
|
||||
MAX_MESSAGE_SIZE,
|
||||
MAX_POOL_SIZE,
|
||||
MAX_WIRE_VERSION,
|
||||
MAX_WRITE_BATCH_SIZE,
|
||||
MIN_POOL_SIZE,
|
||||
ORDERED_TYPES,
|
||||
WAIT_QUEUE_TIMEOUT,
|
||||
)
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.asynchronous.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
_debug_log,
|
||||
_verbose_connection_error_reason,
|
||||
)
|
||||
from pymongo.asynchronous.monitoring import (
|
||||
ConnectionCheckOutFailedReason,
|
||||
ConnectionClosedReason,
|
||||
_EventListeners,
|
||||
)
|
||||
from pymongo.asynchronous.network import command, receive_message
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference
|
||||
from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
AutoReconnect,
|
||||
ConfigurationError,
|
||||
@ -86,8 +62,21 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
WaitQueueTimeoutError,
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
_debug_log,
|
||||
_verbose_connection_error_reason,
|
||||
)
|
||||
from pymongo.monitoring import (
|
||||
ConnectionCheckOutFailedReason,
|
||||
ConnectionClosedReason,
|
||||
)
|
||||
from pymongo.network_layer import async_sendall
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
@ -96,22 +85,19 @@ from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.auth import MongoCredential, _AuthContext
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.compression_support import (
|
||||
CompressionSettings,
|
||||
from pymongo.asynchronous.auth import _AuthContext
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
|
||||
from pymongo.compression_support import (
|
||||
SnappyContext,
|
||||
ZlibContext,
|
||||
ZstdContext,
|
||||
)
|
||||
from pymongo.asynchronous.message import _OpMsg, _OpReply
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.typings import ClusterTime, _Address, _CollationIn
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.pyopenssl_context import SSLContext, _sslConn
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import ClusterTime, _Address, _CollationIn
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
try:
|
||||
@ -191,236 +177,6 @@ else:
|
||||
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
|
||||
|
||||
|
||||
_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}}
|
||||
|
||||
if sys.platform.startswith("linux"):
|
||||
_METADATA["os"] = {
|
||||
"type": platform.system(),
|
||||
"name": platform.system(),
|
||||
"architecture": platform.machine(),
|
||||
# Kernel version (e.g. 4.4.0-17-generic).
|
||||
"version": platform.release(),
|
||||
}
|
||||
elif sys.platform == "darwin":
|
||||
_METADATA["os"] = {
|
||||
"type": platform.system(),
|
||||
"name": platform.system(),
|
||||
"architecture": platform.machine(),
|
||||
# (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin
|
||||
# kernel version.
|
||||
"version": platform.mac_ver()[0],
|
||||
}
|
||||
elif sys.platform == "win32":
|
||||
_ver = sys.getwindowsversion()
|
||||
_METADATA["os"] = {
|
||||
"type": "Windows",
|
||||
"name": "Windows",
|
||||
# Avoid using platform calls, see PYTHON-4455.
|
||||
"architecture": os.environ.get("PROCESSOR_ARCHITECTURE") or platform.machine(),
|
||||
# Windows patch level (e.g. 10.0.17763-SP0).
|
||||
"version": ".".join(map(str, _ver[:3])) + f"-SP{_ver[-1] or '0'}",
|
||||
}
|
||||
elif sys.platform.startswith("java"):
|
||||
_name, _ver, _arch = platform.java_ver()[-1]
|
||||
_METADATA["os"] = {
|
||||
# Linux, Windows 7, Mac OS X, etc.
|
||||
"type": _name,
|
||||
"name": _name,
|
||||
# x86, x86_64, AMD64, etc.
|
||||
"architecture": _arch,
|
||||
# Linux kernel version, OSX version, etc.
|
||||
"version": _ver,
|
||||
}
|
||||
else:
|
||||
# Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11)
|
||||
_aliased = platform.system_alias(platform.system(), platform.release(), platform.version())
|
||||
_METADATA["os"] = {
|
||||
"type": platform.system(),
|
||||
"name": " ".join([part for part in _aliased[:2] if part]),
|
||||
"architecture": platform.machine(),
|
||||
"version": _aliased[2],
|
||||
}
|
||||
|
||||
if platform.python_implementation().startswith("PyPy"):
|
||||
_METADATA["platform"] = " ".join(
|
||||
(
|
||||
platform.python_implementation(),
|
||||
".".join(map(str, sys.pypy_version_info)), # type: ignore
|
||||
"(Python %s)" % ".".join(map(str, sys.version_info)),
|
||||
)
|
||||
)
|
||||
elif sys.platform.startswith("java"):
|
||||
_METADATA["platform"] = " ".join(
|
||||
(
|
||||
platform.python_implementation(),
|
||||
".".join(map(str, sys.version_info)),
|
||||
"(%s)" % " ".join((platform.system(), platform.release())),
|
||||
)
|
||||
)
|
||||
else:
|
||||
_METADATA["platform"] = " ".join(
|
||||
(platform.python_implementation(), ".".join(map(str, sys.version_info)))
|
||||
)
|
||||
|
||||
DOCKER_ENV_PATH = "/.dockerenv"
|
||||
ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST"
|
||||
|
||||
RUNTIME_NAME_DOCKER = "docker"
|
||||
ORCHESTRATOR_NAME_K8S = "kubernetes"
|
||||
|
||||
|
||||
def get_container_env_info() -> dict[str, str]:
|
||||
"""Returns the runtime and orchestrator of a container.
|
||||
If neither value is present, the metadata client.env.container field will be omitted."""
|
||||
container = {}
|
||||
|
||||
if Path(DOCKER_ENV_PATH).exists():
|
||||
container["runtime"] = RUNTIME_NAME_DOCKER
|
||||
if os.getenv(ENV_VAR_K8S):
|
||||
container["orchestrator"] = ORCHESTRATOR_NAME_K8S
|
||||
|
||||
return container
|
||||
|
||||
|
||||
def _is_lambda() -> bool:
|
||||
if os.getenv("AWS_LAMBDA_RUNTIME_API"):
|
||||
return True
|
||||
env = os.getenv("AWS_EXECUTION_ENV")
|
||||
if env:
|
||||
return env.startswith("AWS_Lambda_")
|
||||
return False
|
||||
|
||||
|
||||
def _is_azure_func() -> bool:
|
||||
return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME"))
|
||||
|
||||
|
||||
def _is_gcp_func() -> bool:
|
||||
return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME"))
|
||||
|
||||
|
||||
def _is_vercel() -> bool:
|
||||
return bool(os.getenv("VERCEL"))
|
||||
|
||||
|
||||
def _is_faas() -> bool:
|
||||
return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel()
|
||||
|
||||
|
||||
def _getenv_int(key: str) -> Optional[int]:
|
||||
"""Like os.getenv but returns an int, or None if the value is missing/malformed."""
|
||||
val = os.getenv(key)
|
||||
if not val:
|
||||
return None
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _metadata_env() -> dict[str, Any]:
|
||||
env: dict[str, Any] = {}
|
||||
container = get_container_env_info()
|
||||
if container:
|
||||
env["container"] = container
|
||||
# Skip if multiple (or no) envs are matched.
|
||||
if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1:
|
||||
return env
|
||||
if _is_lambda():
|
||||
env["name"] = "aws.lambda"
|
||||
region = os.getenv("AWS_REGION")
|
||||
if region:
|
||||
env["region"] = region
|
||||
memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
|
||||
if memory_mb is not None:
|
||||
env["memory_mb"] = memory_mb
|
||||
elif _is_azure_func():
|
||||
env["name"] = "azure.func"
|
||||
elif _is_gcp_func():
|
||||
env["name"] = "gcp.func"
|
||||
region = os.getenv("FUNCTION_REGION")
|
||||
if region:
|
||||
env["region"] = region
|
||||
memory_mb = _getenv_int("FUNCTION_MEMORY_MB")
|
||||
if memory_mb is not None:
|
||||
env["memory_mb"] = memory_mb
|
||||
timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC")
|
||||
if timeout_sec is not None:
|
||||
env["timeout_sec"] = timeout_sec
|
||||
elif _is_vercel():
|
||||
env["name"] = "vercel"
|
||||
region = os.getenv("VERCEL_REGION")
|
||||
if region:
|
||||
env["region"] = region
|
||||
return env
|
||||
|
||||
|
||||
_MAX_METADATA_SIZE = 512
|
||||
|
||||
|
||||
# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations
|
||||
def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None:
|
||||
"""Perform metadata truncation."""
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 1. Omit fields from env except env.name.
|
||||
env_name = metadata.get("env", {}).get("name")
|
||||
if env_name:
|
||||
metadata["env"] = {"name": env_name}
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 2. Omit fields from os except os.type.
|
||||
os_type = metadata.get("os", {}).get("type")
|
||||
if os_type:
|
||||
metadata["os"] = {"type": os_type}
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 3. Omit the env document entirely.
|
||||
metadata.pop("env", None)
|
||||
encoded_size = len(bson.encode(metadata))
|
||||
if encoded_size <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 4. Truncate platform.
|
||||
overflow = encoded_size - _MAX_METADATA_SIZE
|
||||
plat = metadata.get("platform", "")
|
||||
if plat:
|
||||
plat = plat[:-overflow]
|
||||
if plat:
|
||||
metadata["platform"] = plat
|
||||
else:
|
||||
metadata.pop("platform", None)
|
||||
encoded_size = len(bson.encode(metadata))
|
||||
if encoded_size <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 5. Truncate driver info.
|
||||
overflow = encoded_size - _MAX_METADATA_SIZE
|
||||
driver = metadata.get("driver", {})
|
||||
if driver:
|
||||
# Truncate driver version.
|
||||
driver_version = driver.get("version")[:-overflow]
|
||||
if len(driver_version) >= len(_METADATA["driver"]["version"]):
|
||||
metadata["driver"]["version"] = driver_version
|
||||
else:
|
||||
metadata["driver"]["version"] = _METADATA["driver"]["version"]
|
||||
encoded_size = len(bson.encode(metadata))
|
||||
if encoded_size <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# Truncate driver name.
|
||||
overflow = encoded_size - _MAX_METADATA_SIZE
|
||||
driver_name = driver.get("name")[:-overflow]
|
||||
if len(driver_name) >= len(_METADATA["driver"]["name"]):
|
||||
metadata["driver"]["name"] = driver_name
|
||||
else:
|
||||
metadata["driver"]["name"] = _METADATA["driver"]["name"]
|
||||
|
||||
|
||||
# If the first getaddrinfo call of this interpreter's life is on a thread,
|
||||
# while the main thread holds the import lock, getaddrinfo deadlocks trying
|
||||
# to import the IDNA codec. Import it here, where presumably we're on the
|
||||
# main thread, to avoid the deadlock. See PYTHON-607.
|
||||
"foo".encode("idna")
|
||||
|
||||
|
||||
def _raise_connection_failure(
|
||||
address: Any,
|
||||
error: Exception,
|
||||
@ -481,238 +237,6 @@ def format_timeout_details(details: Optional[dict[str, float]]) -> str:
|
||||
return result
|
||||
|
||||
|
||||
class PoolOptions:
|
||||
"""Read only connection pool options for an AsyncMongoClient.
|
||||
|
||||
Should not be instantiated directly by application developers. Access
|
||||
a client's pool options via
|
||||
:attr:`~pymongo.client_options.ClientOptions.pool_options` instead::
|
||||
|
||||
pool_opts = client.options.pool_options
|
||||
pool_opts.max_pool_size
|
||||
pool_opts.min_pool_size
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"__max_pool_size",
|
||||
"__min_pool_size",
|
||||
"__max_idle_time_seconds",
|
||||
"__connect_timeout",
|
||||
"__socket_timeout",
|
||||
"__wait_queue_timeout",
|
||||
"__ssl_context",
|
||||
"__tls_allow_invalid_hostnames",
|
||||
"__event_listeners",
|
||||
"__appname",
|
||||
"__driver",
|
||||
"__metadata",
|
||||
"__compression_settings",
|
||||
"__max_connecting",
|
||||
"__pause_enabled",
|
||||
"__server_api",
|
||||
"__load_balanced",
|
||||
"__credentials",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_pool_size: int = MAX_POOL_SIZE,
|
||||
min_pool_size: int = MIN_POOL_SIZE,
|
||||
max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC,
|
||||
connect_timeout: Optional[float] = None,
|
||||
socket_timeout: Optional[float] = None,
|
||||
wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT,
|
||||
ssl_context: Optional[SSLContext] = None,
|
||||
tls_allow_invalid_hostnames: bool = False,
|
||||
event_listeners: Optional[_EventListeners] = None,
|
||||
appname: Optional[str] = None,
|
||||
driver: Optional[DriverInfo] = None,
|
||||
compression_settings: Optional[CompressionSettings] = None,
|
||||
max_connecting: int = MAX_CONNECTING,
|
||||
pause_enabled: bool = True,
|
||||
server_api: Optional[ServerApi] = None,
|
||||
load_balanced: Optional[bool] = None,
|
||||
credentials: Optional[MongoCredential] = None,
|
||||
):
|
||||
self.__max_pool_size = max_pool_size
|
||||
self.__min_pool_size = min_pool_size
|
||||
self.__max_idle_time_seconds = max_idle_time_seconds
|
||||
self.__connect_timeout = connect_timeout
|
||||
self.__socket_timeout = socket_timeout
|
||||
self.__wait_queue_timeout = wait_queue_timeout
|
||||
self.__ssl_context = ssl_context
|
||||
self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames
|
||||
self.__event_listeners = event_listeners
|
||||
self.__appname = appname
|
||||
self.__driver = driver
|
||||
self.__compression_settings = compression_settings
|
||||
self.__max_connecting = max_connecting
|
||||
self.__pause_enabled = pause_enabled
|
||||
self.__server_api = server_api
|
||||
self.__load_balanced = load_balanced
|
||||
self.__credentials = credentials
|
||||
self.__metadata = copy.deepcopy(_METADATA)
|
||||
if appname:
|
||||
self.__metadata["application"] = {"name": appname}
|
||||
|
||||
# Combine the "driver" AsyncMongoClient option with PyMongo's info, like:
|
||||
# {
|
||||
# 'driver': {
|
||||
# 'name': 'PyMongo|MyDriver',
|
||||
# 'version': '4.2.0|1.2.3',
|
||||
# },
|
||||
# 'platform': 'CPython 3.8.0|MyPlatform'
|
||||
# }
|
||||
if driver:
|
||||
if driver.name:
|
||||
self.__metadata["driver"]["name"] = "{}|{}".format(
|
||||
_METADATA["driver"]["name"],
|
||||
driver.name,
|
||||
)
|
||||
if driver.version:
|
||||
self.__metadata["driver"]["version"] = "{}|{}".format(
|
||||
_METADATA["driver"]["version"],
|
||||
driver.version,
|
||||
)
|
||||
if driver.platform:
|
||||
self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform)
|
||||
|
||||
env = _metadata_env()
|
||||
if env:
|
||||
self.__metadata["env"] = env
|
||||
|
||||
_truncate_metadata(self.__metadata)
|
||||
|
||||
@property
|
||||
def _credentials(self) -> Optional[MongoCredential]:
|
||||
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
|
||||
return self.__credentials
|
||||
|
||||
@property
|
||||
def non_default_options(self) -> dict[str, Any]:
|
||||
"""The non-default options this pool was created with.
|
||||
|
||||
Added for CMAP's :class:`PoolCreatedEvent`.
|
||||
"""
|
||||
opts = {}
|
||||
if self.__max_pool_size != MAX_POOL_SIZE:
|
||||
opts["maxPoolSize"] = self.__max_pool_size
|
||||
if self.__min_pool_size != MIN_POOL_SIZE:
|
||||
opts["minPoolSize"] = self.__min_pool_size
|
||||
if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC:
|
||||
assert self.__max_idle_time_seconds is not None
|
||||
opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000
|
||||
if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT:
|
||||
assert self.__wait_queue_timeout is not None
|
||||
opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000
|
||||
if self.__max_connecting != MAX_CONNECTING:
|
||||
opts["maxConnecting"] = self.__max_connecting
|
||||
return opts
|
||||
|
||||
@property
|
||||
def max_pool_size(self) -> float:
|
||||
"""The maximum allowable number of concurrent connections to each
|
||||
connected server. Requests to a server will block if there are
|
||||
`maxPoolSize` outstanding connections to the requested server.
|
||||
Defaults to 100. Cannot be 0.
|
||||
|
||||
When a server's pool has reached `max_pool_size`, operations for that
|
||||
server block waiting for a socket to be returned to the pool. If
|
||||
``waitQueueTimeoutMS`` is set, a blocked operation will raise
|
||||
:exc:`~pymongo.errors.ConnectionFailure` after a timeout.
|
||||
By default ``waitQueueTimeoutMS`` is not set.
|
||||
"""
|
||||
return self.__max_pool_size
|
||||
|
||||
@property
|
||||
def min_pool_size(self) -> int:
|
||||
"""The minimum required number of concurrent connections that the pool
|
||||
will maintain to each connected server. Default is 0.
|
||||
"""
|
||||
return self.__min_pool_size
|
||||
|
||||
@property
|
||||
def max_connecting(self) -> int:
|
||||
"""The maximum number of concurrent connection creation attempts per
|
||||
pool. Defaults to 2.
|
||||
"""
|
||||
return self.__max_connecting
|
||||
|
||||
@property
|
||||
def pause_enabled(self) -> bool:
|
||||
return self.__pause_enabled
|
||||
|
||||
@property
|
||||
def max_idle_time_seconds(self) -> Optional[int]:
|
||||
"""The maximum number of seconds that a connection can remain
|
||||
idle in the pool before being removed and replaced. Defaults to
|
||||
`None` (no limit).
|
||||
"""
|
||||
return self.__max_idle_time_seconds
|
||||
|
||||
@property
|
||||
def connect_timeout(self) -> Optional[float]:
|
||||
"""How long a connection can take to be opened before timing out."""
|
||||
return self.__connect_timeout
|
||||
|
||||
@property
|
||||
def socket_timeout(self) -> Optional[float]:
|
||||
"""How long a send or receive on a socket can take before timing out."""
|
||||
return self.__socket_timeout
|
||||
|
||||
@property
|
||||
def wait_queue_timeout(self) -> Optional[int]:
|
||||
"""How long a thread will wait for a socket from the pool if the pool
|
||||
has no free sockets.
|
||||
"""
|
||||
return self.__wait_queue_timeout
|
||||
|
||||
@property
|
||||
def _ssl_context(self) -> Optional[SSLContext]:
|
||||
"""An SSLContext instance or None."""
|
||||
return self.__ssl_context
|
||||
|
||||
@property
|
||||
def tls_allow_invalid_hostnames(self) -> bool:
|
||||
"""If True skip ssl.match_hostname."""
|
||||
return self.__tls_allow_invalid_hostnames
|
||||
|
||||
@property
|
||||
def _event_listeners(self) -> Optional[_EventListeners]:
|
||||
"""An instance of pymongo.monitoring._EventListeners."""
|
||||
return self.__event_listeners
|
||||
|
||||
@property
|
||||
def appname(self) -> Optional[str]:
|
||||
"""The application name, for sending with hello in server handshake."""
|
||||
return self.__appname
|
||||
|
||||
@property
|
||||
def driver(self) -> Optional[DriverInfo]:
|
||||
"""Driver name and version, for sending with hello in handshake."""
|
||||
return self.__driver
|
||||
|
||||
@property
|
||||
def _compression_settings(self) -> Optional[CompressionSettings]:
|
||||
return self.__compression_settings
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
"""A dict of metadata about the application, driver, os, and platform."""
|
||||
return self.__metadata.copy()
|
||||
|
||||
@property
|
||||
def server_api(self) -> Optional[ServerApi]:
|
||||
"""A pymongo.server_api.ServerApi or None."""
|
||||
return self.__server_api
|
||||
|
||||
@property
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if this Pool is configured in load balanced mode."""
|
||||
return self.__load_balanced
|
||||
|
||||
|
||||
class _CancellationContext:
|
||||
def __init__(self) -> None:
|
||||
self._cancelled = False
|
||||
@ -727,7 +251,7 @@ class _CancellationContext:
|
||||
return self._cancelled
|
||||
|
||||
|
||||
class Connection:
|
||||
class AsyncConnection:
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
:param conn: a raw connection object
|
||||
@ -946,7 +470,7 @@ class Connection:
|
||||
self.more_to_come = reply.more_to_come
|
||||
unpacked_docs = reply.unpack_response()
|
||||
response_doc = unpacked_docs[0]
|
||||
helpers._check_command_response(response_doc, self.max_wire_version)
|
||||
helpers_shared._check_command_response(response_doc, self.max_wire_version)
|
||||
return response_doc
|
||||
|
||||
@_handle_reauth
|
||||
@ -962,7 +486,7 @@ class Connection:
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
client: Optional[AsyncMongoClient] = None,
|
||||
retryable_write: bool = False,
|
||||
publish_events: bool = True,
|
||||
@ -982,7 +506,7 @@ class Connection:
|
||||
:param parse_write_concern_error: Whether to parse the
|
||||
``writeConcernError`` field in the command response.
|
||||
:param collation: The collation for this command.
|
||||
:param session: optional ClientSession instance.
|
||||
:param session: optional AsyncClientSession instance.
|
||||
:param client: optional AsyncMongoClient for gossipping $clusterTime.
|
||||
:param retryable_write: True if this command is a retryable write.
|
||||
:param publish_events: Should we publish events for this command?
|
||||
@ -1099,7 +623,7 @@ class Connection:
|
||||
result = reply.command_response(codec_options)
|
||||
|
||||
# Raises NotPrimaryError or OperationFailure.
|
||||
helpers._check_command_response(result, self.max_wire_version)
|
||||
helpers_shared._check_command_response(result, self.max_wire_version)
|
||||
return result
|
||||
|
||||
async def authenticate(self, reauthenticate: bool = False) -> None:
|
||||
@ -1137,7 +661,7 @@ class Connection:
|
||||
)
|
||||
|
||||
def validate_session(
|
||||
self, client: Optional[AsyncMongoClient], session: Optional[ClientSession]
|
||||
self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession]
|
||||
) -> None:
|
||||
"""Validate this session before use with client.
|
||||
|
||||
@ -1190,7 +714,7 @@ class Connection:
|
||||
def send_cluster_time(
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
session: Optional[ClientSession],
|
||||
session: Optional[AsyncClientSession],
|
||||
client: Optional[AsyncMongoClient],
|
||||
) -> None:
|
||||
"""Add $clusterTime."""
|
||||
@ -1250,7 +774,7 @@ class Connection:
|
||||
return hash(self.conn)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Connection({}){} at {}".format(
|
||||
return "AsyncConnection({}){} at {}".format(
|
||||
repr(self.conn),
|
||||
self.closed and " CLOSED" or "",
|
||||
id(self),
|
||||
@ -1441,7 +965,7 @@ class Pool:
|
||||
"""
|
||||
:param address: a (hostname, port) tuple
|
||||
:param options: a PoolOptions instance
|
||||
:param handshake: whether to call hello for each new Connection
|
||||
:param handshake: whether to call hello for each new AsyncConnection
|
||||
"""
|
||||
if options.pause_enabled:
|
||||
self.state = PoolState.PAUSED
|
||||
@ -1512,7 +1036,7 @@ class Pool:
|
||||
# Retain references to pinned connections to prevent the CPython GC
|
||||
# from thinking that a cursor's pinned connection can be GC'd when the
|
||||
# cursor is GC'd (see PYTHON-2751).
|
||||
self.__pinned_sockets: set[Connection] = set()
|
||||
self.__pinned_sockets: set[AsyncConnection] = set()
|
||||
self.ncursors = 0
|
||||
self.ntxns = 0
|
||||
|
||||
@ -1700,8 +1224,8 @@ class Pool:
|
||||
self.requests -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
|
||||
"""Connect to Mongo and return a new Connection.
|
||||
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
|
||||
"""Connect to Mongo and return a new AsyncConnection.
|
||||
|
||||
Can raise ConnectionFailure.
|
||||
|
||||
@ -1751,7 +1275,7 @@ class Pool:
|
||||
|
||||
raise
|
||||
|
||||
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
async with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
try:
|
||||
@ -1771,10 +1295,10 @@ class Pool:
|
||||
@contextlib.asynccontextmanager
|
||||
async def checkout(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
) -> AsyncGenerator[Connection, None]:
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
"""Get a connection from the pool. Use with a "with" statement.
|
||||
|
||||
Returns a :class:`Connection` object wrapping a connected
|
||||
Returns a :class:`AsyncConnection` object wrapping a connected
|
||||
:class:`socket.socket`.
|
||||
|
||||
This method should always be used in a with-statement::
|
||||
@ -1874,8 +1398,8 @@ class Pool:
|
||||
|
||||
async def _get_conn(
|
||||
self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None
|
||||
) -> Connection:
|
||||
"""Get or create a Connection. Can raise ConnectionFailure."""
|
||||
) -> AsyncConnection:
|
||||
"""Get or create a AsyncConnection. Can raise ConnectionFailure."""
|
||||
# We use the pid here to avoid issues with fork / multiprocessing.
|
||||
# See test.test_client:TestClient.test_fork for an example of
|
||||
# what could go wrong otherwise
|
||||
@ -1998,7 +1522,7 @@ class Pool:
|
||||
conn.active = True
|
||||
return conn
|
||||
|
||||
async def checkin(self, conn: Connection) -> None:
|
||||
async def checkin(self, conn: AsyncConnection) -> None:
|
||||
"""Return the connection to the pool, or if it's closed discard it.
|
||||
|
||||
:param conn: The connection to check into the pool.
|
||||
@ -2070,7 +1594,7 @@ class Pool:
|
||||
self.operation_count -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def _perished(self, conn: Connection) -> bool:
|
||||
def _perished(self, conn: AsyncConnection) -> bool:
|
||||
"""Return True and close the connection if it is "perished".
|
||||
|
||||
This side-effecty function checks if this socket has been idle for
|
||||
|
||||
@ -1,624 +0,0 @@
|
||||
# Copyright 2012-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License",
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for choosing which member of a replica set to read from."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import abc
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
|
||||
|
||||
from pymongo.asynchronous import max_staleness_selectors
|
||||
from pymongo.asynchronous.server_selectors import (
|
||||
member_with_tags_server_selector,
|
||||
secondary_with_tags_server_selector,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
from pymongo.asynchronous.topology_description import TopologyDescription
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_PRIMARY = 0
|
||||
_PRIMARY_PREFERRED = 1
|
||||
_SECONDARY = 2
|
||||
_SECONDARY_PREFERRED = 3
|
||||
_NEAREST = 4
|
||||
|
||||
|
||||
_MONGOS_MODES = (
|
||||
"primary",
|
||||
"primaryPreferred",
|
||||
"secondary",
|
||||
"secondaryPreferred",
|
||||
"nearest",
|
||||
)
|
||||
|
||||
_Hedge = Mapping[str, Any]
|
||||
_TagSets = Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
|
||||
"""Validate tag sets for a MongoClient."""
|
||||
if tag_sets is None:
|
||||
return tag_sets
|
||||
|
||||
if not isinstance(tag_sets, (list, tuple)):
|
||||
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
|
||||
if len(tag_sets) == 0:
|
||||
raise ValueError(
|
||||
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
|
||||
)
|
||||
|
||||
for tags in tag_sets:
|
||||
if not isinstance(tags, abc.Mapping):
|
||||
raise TypeError(
|
||||
f"Tag set {tags!r} invalid, must be an instance of dict, "
|
||||
"bson.son.SON or other type that inherits from "
|
||||
"collection.Mapping"
|
||||
)
|
||||
|
||||
return list(tag_sets)
|
||||
|
||||
|
||||
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
|
||||
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
|
||||
|
||||
|
||||
# Some duplication with common.py to avoid import cycle.
|
||||
def _validate_max_staleness(max_staleness: Any) -> int:
|
||||
"""Validate max_staleness."""
|
||||
if max_staleness == -1:
|
||||
return -1
|
||||
|
||||
if not isinstance(max_staleness, int):
|
||||
raise TypeError(_invalid_max_staleness_msg(max_staleness))
|
||||
|
||||
if max_staleness <= 0:
|
||||
raise ValueError(_invalid_max_staleness_msg(max_staleness))
|
||||
|
||||
return max_staleness
|
||||
|
||||
|
||||
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
|
||||
"""Validate hedge."""
|
||||
if hedge is None:
|
||||
return None
|
||||
|
||||
if not isinstance(hedge, dict):
|
||||
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
|
||||
|
||||
return hedge
|
||||
|
||||
|
||||
class _ServerMode:
|
||||
"""Base class for all read preferences."""
|
||||
|
||||
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: int,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
self.__mongos_mode = _MONGOS_MODES[mode]
|
||||
self.__mode = mode
|
||||
self.__tag_sets = _validate_tag_sets(tag_sets)
|
||||
self.__max_staleness = _validate_max_staleness(max_staleness)
|
||||
self.__hedge = _validate_hedge(hedge)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of this read preference."""
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def mongos_mode(self) -> str:
|
||||
"""The mongos mode of this read preference."""
|
||||
return self.__mongos_mode
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""Read preference as a document."""
|
||||
doc: dict[str, Any] = {"mode": self.__mongos_mode}
|
||||
if self.__tag_sets not in (None, [{}]):
|
||||
doc["tags"] = self.__tag_sets
|
||||
if self.__max_staleness != -1:
|
||||
doc["maxStalenessSeconds"] = self.__max_staleness
|
||||
if self.__hedge not in (None, {}):
|
||||
doc["hedge"] = self.__hedge
|
||||
return doc
|
||||
|
||||
@property
|
||||
def mode(self) -> int:
|
||||
"""The mode of this read preference instance."""
|
||||
return self.__mode
|
||||
|
||||
@property
|
||||
def tag_sets(self) -> _TagSets:
|
||||
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
|
||||
read only from members whose ``dc`` tag has the value ``"ny"``.
|
||||
To specify a priority-order for tag sets, provide a list of
|
||||
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
|
||||
set, ``{}``, means "read from any member that matches the mode,
|
||||
ignoring tags." MongoClient tries each set of tags in turn
|
||||
until it finds a set of tags with at least one matching member.
|
||||
For example, to only send a query to an analytic node::
|
||||
|
||||
Nearest(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
Or using :class:`SecondaryPreferred`::
|
||||
|
||||
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
.. seealso:: `Data-Center Awareness
|
||||
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
|
||||
"""
|
||||
return list(self.__tag_sets) if self.__tag_sets else [{}]
|
||||
|
||||
@property
|
||||
def max_staleness(self) -> int:
|
||||
"""The maximum estimated length of time (in seconds) a replica set
|
||||
secondary can fall behind the primary in replication before it will
|
||||
no longer be selected for operations, or -1 for no maximum.
|
||||
"""
|
||||
return self.__max_staleness
|
||||
|
||||
@property
|
||||
def hedge(self) -> Optional[_Hedge]:
|
||||
"""The read preference ``hedge`` parameter.
|
||||
|
||||
A dictionary that configures how the server will perform hedged reads.
|
||||
It consists of the following keys:
|
||||
|
||||
- ``enabled``: Enables or disables hedged reads in sharded clusters.
|
||||
|
||||
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
|
||||
``nearest`` read preference. To explicitly enable hedged reads, set
|
||||
the ``enabled`` key to ``true``::
|
||||
|
||||
>>> Nearest(hedge={'enabled': True})
|
||||
|
||||
To explicitly disable hedged reads, set the ``enabled`` key to
|
||||
``False``::
|
||||
|
||||
>>> Nearest(hedge={'enabled': False})
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
return self.__hedge
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
"""The wire protocol version the server must support.
|
||||
|
||||
Some read preferences impose version requirements on all servers (e.g.
|
||||
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
|
||||
|
||||
All servers' maxWireVersion must be at least this read preference's
|
||||
`min_wire_version`, or the driver raises
|
||||
:exc:`~pymongo.errors.ConfigurationError`.
|
||||
"""
|
||||
return 0 if self.__max_staleness == -1 else 5
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
|
||||
self.name,
|
||||
self.__tag_sets,
|
||||
self.__max_staleness,
|
||||
self.__hedge,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return (
|
||||
self.mode == other.mode
|
||||
and self.tag_sets == other.tag_sets
|
||||
and self.max_staleness == other.max_staleness
|
||||
and self.hedge == other.hedge
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
"""Return value of object for pickling.
|
||||
|
||||
Needed explicitly because __slots__() defined.
|
||||
"""
|
||||
return {
|
||||
"mode": self.__mode,
|
||||
"tag_sets": self.__tag_sets,
|
||||
"max_staleness": self.__max_staleness,
|
||||
"hedge": self.__hedge,
|
||||
}
|
||||
|
||||
def __setstate__(self, value: Mapping[str, Any]) -> None:
|
||||
"""Restore from pickling."""
|
||||
self.__mode = value["mode"]
|
||||
self.__mongos_mode = _MONGOS_MODES[self.__mode]
|
||||
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
|
||||
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
|
||||
self.__hedge = _validate_hedge(value["hedge"])
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
return selection
|
||||
|
||||
|
||||
class Primary(_ServerMode):
|
||||
"""Primary read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed if the server
|
||||
is standalone or a replica set primary.
|
||||
* When connected to a mongos queries are sent to the primary of a shard.
|
||||
* When connected to a replica set queries are sent to the primary of
|
||||
the replica set.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(_PRIMARY)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to a Selection."""
|
||||
return selection.primary_selection
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Primary()"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return other.mode == _PRIMARY
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class PrimaryPreferred(_ServerMode):
|
||||
"""PrimaryPreferred read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are sent to the primary of a shard if
|
||||
available, otherwise a shard secondary.
|
||||
* When connected to a replica set queries are sent to the primary if
|
||||
available, otherwise a secondary.
|
||||
|
||||
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
|
||||
created reads will be routed to an available secondary until the
|
||||
primary of the replica set is discovered.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` to use if the primary is not
|
||||
available.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` to use if the primary is not available.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
if selection.primary:
|
||||
return selection.primary_selection
|
||||
else:
|
||||
return secondary_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
|
||||
class Secondary(_ServerMode):
|
||||
"""Secondary read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are distributed among shard
|
||||
secondaries. An error is raised if no secondaries are available.
|
||||
* When connected to a replica set queries are distributed among
|
||||
secondaries. An error is raised if no secondaries are available.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` for this read preference.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` for this read preference.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
return secondary_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
|
||||
class SecondaryPreferred(_ServerMode):
|
||||
"""SecondaryPreferred read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are distributed among shard
|
||||
secondaries, or the shard primary if no secondary is available.
|
||||
* When connected to a replica set queries are distributed among
|
||||
secondaries, or the primary if no secondary is available.
|
||||
|
||||
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
|
||||
created reads will be routed to the primary of the replica set until
|
||||
an available secondary is discovered.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` for this read preference.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` for this read preference.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
secondaries = secondary_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
if secondaries:
|
||||
return secondaries
|
||||
else:
|
||||
return selection.primary_selection
|
||||
|
||||
|
||||
class Nearest(_ServerMode):
|
||||
"""Nearest read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are distributed among all members of
|
||||
a shard.
|
||||
* When connected to a replica set queries are distributed among all
|
||||
members.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` for this read preference.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` for this read preference.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
return member_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
|
||||
class _AggWritePref:
|
||||
"""Agg $out/$merge write preference.
|
||||
|
||||
* If there are readable servers and there is any pre-5.0 server, use
|
||||
primary read preference.
|
||||
* Otherwise use `pref` read preference.
|
||||
|
||||
:param pref: The read preference to use on MongoDB 5.0+.
|
||||
"""
|
||||
|
||||
__slots__ = ("pref", "effective_pref")
|
||||
|
||||
def __init__(self, pref: _ServerMode):
|
||||
self.pref = pref
|
||||
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
|
||||
|
||||
def selection_hook(self, topology_description: TopologyDescription) -> None:
|
||||
common_wv = topology_description.common_wire_version
|
||||
if (
|
||||
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
|
||||
and common_wv
|
||||
and common_wv < 13
|
||||
):
|
||||
self.effective_pref = ReadPreference.PRIMARY
|
||||
else:
|
||||
self.effective_pref = self.pref
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to a Selection."""
|
||||
return self.effective_pref(selection)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"_AggWritePref(pref={self.pref!r})"
|
||||
|
||||
# Proxy other calls to the effective_pref so that _AggWritePref can be
|
||||
# used in place of an actual read preference.
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.effective_pref, name)
|
||||
|
||||
|
||||
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
|
||||
|
||||
|
||||
def make_read_preference(
|
||||
mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
|
||||
) -> _ServerMode:
|
||||
if mode == _PRIMARY:
|
||||
if tag_sets not in (None, [{}]):
|
||||
raise ConfigurationError("Read preference primary cannot be combined with tags")
|
||||
if max_staleness != -1:
|
||||
raise ConfigurationError(
|
||||
"Read preference primary cannot be combined with maxStalenessSeconds"
|
||||
)
|
||||
return Primary()
|
||||
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
|
||||
|
||||
|
||||
_MODES = (
|
||||
"PRIMARY",
|
||||
"PRIMARY_PREFERRED",
|
||||
"SECONDARY",
|
||||
"SECONDARY_PREFERRED",
|
||||
"NEAREST",
|
||||
)
|
||||
|
||||
|
||||
class ReadPreference:
|
||||
"""An enum that defines some commonly used read preference modes.
|
||||
|
||||
Apps can also create a custom read preference, for example::
|
||||
|
||||
Nearest(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
See :doc:`/examples/high_availability` for code examples.
|
||||
|
||||
A read preference is used in three cases:
|
||||
|
||||
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
|
||||
|
||||
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
|
||||
set primary.
|
||||
- All other modes allow queries to standalone servers, to a replica set
|
||||
primary, or to replica set secondaries.
|
||||
|
||||
:class:`~pymongo.mongo_client.MongoClient` initialized with the
|
||||
``replicaSet`` option:
|
||||
|
||||
- ``PRIMARY``: Read from the primary. This is the default, and provides the
|
||||
strongest consistency. If no primary is available, raise
|
||||
:class:`~pymongo.errors.AutoReconnect`.
|
||||
|
||||
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
|
||||
none, read from a secondary.
|
||||
|
||||
- ``SECONDARY``: Read from a secondary. If no secondary is available,
|
||||
raise :class:`~pymongo.errors.AutoReconnect`.
|
||||
|
||||
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
|
||||
from the primary.
|
||||
|
||||
- ``NEAREST``: Read from any member.
|
||||
|
||||
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
|
||||
sharded cluster of replica sets:
|
||||
|
||||
- ``PRIMARY``: Read from the primary of the shard, or raise
|
||||
:class:`~pymongo.errors.OperationFailure` if there is none.
|
||||
This is the default.
|
||||
|
||||
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
|
||||
none, read from a secondary of the shard.
|
||||
|
||||
- ``SECONDARY``: Read from a secondary of the shard, or raise
|
||||
:class:`~pymongo.errors.OperationFailure` if there is none.
|
||||
|
||||
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
|
||||
otherwise from the shard primary.
|
||||
|
||||
- ``NEAREST``: Read from any shard member.
|
||||
"""
|
||||
|
||||
PRIMARY = Primary()
|
||||
PRIMARY_PREFERRED = PrimaryPreferred()
|
||||
SECONDARY = Secondary()
|
||||
SECONDARY_PREFERRED = SecondaryPreferred()
|
||||
NEAREST = Nearest()
|
||||
|
||||
|
||||
def read_pref_mode_from_name(name: str) -> int:
|
||||
"""Get the read preference mode from mongos/uri name."""
|
||||
return _MONGOS_MODES.index(name)
|
||||
|
||||
|
||||
class MovingAverage:
|
||||
"""Tracks an exponentially-weighted moving average."""
|
||||
|
||||
average: Optional[float]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.average = None
|
||||
|
||||
def add_sample(self, sample: float) -> None:
|
||||
if sample < 0:
|
||||
# Likely system time change while waiting for hello response
|
||||
# and not using time.monotonic. Ignore it, the next one will
|
||||
# probably be valid.
|
||||
return
|
||||
if self.average is None:
|
||||
self.average = sample
|
||||
else:
|
||||
# The Server Selection Spec requires an exponentially weighted
|
||||
# average with alpha = 0.2.
|
||||
self.average = 0.8 * self.average + 0.2 * sample
|
||||
|
||||
def get(self) -> Optional[float]:
|
||||
"""Get the calculated average, or None if no samples yet."""
|
||||
return self.average
|
||||
|
||||
def reset(self) -> None:
|
||||
self.average = None
|
||||
@ -1,133 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Represent a response from the server."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.asynchronous.message import _OpMsg, _OpReply
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.typings import _Address, _DocumentOut
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class Response:
|
||||
__slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Union[_OpMsg, _OpReply],
|
||||
address: _Address,
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
docs: Sequence[Mapping[str, Any]],
|
||||
):
|
||||
"""Represent a response from the server.
|
||||
|
||||
:param data: A network response message.
|
||||
:param address: (host, port) of the source server.
|
||||
:param request_id: The request id of this operation.
|
||||
:param duration: The duration of the operation.
|
||||
:param from_command: if the response is the result of a db command.
|
||||
"""
|
||||
self._data = data
|
||||
self._address = address
|
||||
self._request_id = request_id
|
||||
self._duration = duration
|
||||
self._from_command = from_command
|
||||
self._docs = docs
|
||||
|
||||
@property
|
||||
def data(self) -> Union[_OpMsg, _OpReply]:
|
||||
"""Server response's raw BSON bytes."""
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def address(self) -> _Address:
|
||||
"""(host, port) of the source server."""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def request_id(self) -> int:
|
||||
"""The request id of this operation."""
|
||||
return self._request_id
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[timedelta]:
|
||||
"""The duration of the operation."""
|
||||
return self._duration
|
||||
|
||||
@property
|
||||
def from_command(self) -> bool:
|
||||
"""If the response is a result from a db command."""
|
||||
return self._from_command
|
||||
|
||||
@property
|
||||
def docs(self) -> Sequence[Mapping[str, Any]]:
|
||||
"""The decoded document(s)."""
|
||||
return self._docs
|
||||
|
||||
|
||||
class PinnedResponse(Response):
|
||||
__slots__ = ("_conn", "_more_to_come")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Union[_OpMsg, _OpReply],
|
||||
address: _Address,
|
||||
conn: Connection,
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
docs: list[_DocumentOut],
|
||||
more_to_come: bool,
|
||||
):
|
||||
"""Represent a response to an exhaust cursor's initial query.
|
||||
|
||||
:param data: A network response message.
|
||||
:param address: (host, port) of the source server.
|
||||
:param conn: The Connection used for the initial query.
|
||||
:param request_id: The request id of this operation.
|
||||
:param duration: The duration of the operation.
|
||||
:param from_command: If the response is the result of a db command.
|
||||
:param docs: List of documents.
|
||||
:param more_to_come: Bool indicating whether cursor is ready to be
|
||||
exhausted.
|
||||
"""
|
||||
super().__init__(data, address, request_id, duration, from_command, docs)
|
||||
self._conn = conn
|
||||
self._more_to_come = more_to_come
|
||||
|
||||
@property
|
||||
def conn(self) -> Connection:
|
||||
"""The Connection used for the initial query.
|
||||
|
||||
The server will send batches on this socket, without waiting for
|
||||
getMores from the client, until the result set is exhausted or there
|
||||
is an error.
|
||||
"""
|
||||
return self._conn
|
||||
|
||||
@property
|
||||
def more_to_come(self) -> bool:
|
||||
"""If true, server is ready to send batches on the socket until the
|
||||
result set is exhausted or there is an error.
|
||||
"""
|
||||
return self._more_to_come
|
||||
@ -27,11 +27,12 @@ from typing import (
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo.asynchronous.helpers import _check_command_response, _handle_reauth
|
||||
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query
|
||||
from pymongo.asynchronous.response import PinnedResponse, Response
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.errors import NotPrimaryError, OperationFailure
|
||||
from pymongo.helpers_shared import _check_command_response
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query
|
||||
from pymongo.response import PinnedResponse, Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from queue import Queue
|
||||
@ -40,11 +41,11 @@ if TYPE_CHECKING:
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
|
||||
from pymongo.asynchronous.monitor import Monitor
|
||||
from pymongo.asynchronous.monitoring import _EventListeners
|
||||
from pymongo.asynchronous.pool import Connection, Pool
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.typings import _DocumentOut
|
||||
from pymongo.asynchronous.pool import AsyncConnection, Pool
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -105,10 +106,23 @@ class Server:
|
||||
"""Check the server's state soon."""
|
||||
self._monitor.request_check()
|
||||
|
||||
async def operation_to_command(
|
||||
self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False
|
||||
) -> tuple[dict[str, Any], str]:
|
||||
cmd, db = operation.as_command(conn, apply_timeout)
|
||||
# Support auto encryption
|
||||
if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
|
||||
cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment]
|
||||
operation.db, cmd, operation.codec_options
|
||||
)
|
||||
operation.update_command(cmd)
|
||||
|
||||
return cmd, db
|
||||
|
||||
@_handle_reauth
|
||||
async def run_operation(
|
||||
self,
|
||||
conn: Connection,
|
||||
conn: AsyncConnection,
|
||||
operation: Union[_Query, _GetMore],
|
||||
read_preference: _ServerMode,
|
||||
listeners: Optional[_EventListeners],
|
||||
@ -121,26 +135,26 @@ class Server:
|
||||
cursors.
|
||||
Can raise ConnectionFailure, OperationFailure, etc.
|
||||
|
||||
:param conn: A Connection instance.
|
||||
:param conn: An AsyncConnection instance.
|
||||
:param operation: A _Query or _GetMore object.
|
||||
:param read_preference: The read preference to use.
|
||||
:param listeners: Instance of _EventListeners or None.
|
||||
:param unpack_res: A callable that decodes the wire protocol response.
|
||||
:param client: An AsyncMongoClient instance.
|
||||
"""
|
||||
duration = None
|
||||
assert listeners is not None
|
||||
publish = listeners.enabled_for_commands
|
||||
start = datetime.now()
|
||||
|
||||
use_cmd = operation.use_command(conn)
|
||||
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
|
||||
cmd, dbn = await self.operation_to_command(operation, conn, use_cmd)
|
||||
if more_to_come:
|
||||
request_id = 0
|
||||
else:
|
||||
message = await operation.get_message(read_preference, conn, use_cmd)
|
||||
message = operation.get_message(read_preference, conn, use_cmd)
|
||||
request_id, data, max_doc_size = self._split_message(message)
|
||||
|
||||
cmd, dbn = await operation.as_command(conn)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
@ -159,7 +173,6 @@ class Server:
|
||||
)
|
||||
|
||||
if publish:
|
||||
cmd, dbn = await operation.as_command(conn)
|
||||
if "$db" not in cmd:
|
||||
cmd["$db"] = dbn
|
||||
assert listeners is not None
|
||||
@ -195,7 +208,7 @@ class Server:
|
||||
)
|
||||
if use_cmd:
|
||||
first = docs[0]
|
||||
await operation.client._process_response(first, operation.session)
|
||||
await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type]
|
||||
_check_command_response(first, conn.max_wire_version)
|
||||
except Exception as exc:
|
||||
duration = datetime.now() - start
|
||||
@ -278,7 +291,7 @@ class Server:
|
||||
)
|
||||
|
||||
# Decrypt response.
|
||||
client = operation.client
|
||||
client = operation.client # type: ignore[assignment]
|
||||
if client and client._encrypter:
|
||||
if use_cmd:
|
||||
decrypted = client._encrypter.decrypt(reply.raw_command_response())
|
||||
@ -286,7 +299,7 @@ class Server:
|
||||
|
||||
response: Response
|
||||
|
||||
if client._should_pin_cursor(operation.session) or operation.exhaust:
|
||||
if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type]
|
||||
conn.pin_cursor()
|
||||
if isinstance(reply, _OpMsg):
|
||||
# In OP_MSG, the server keeps sending only if the
|
||||
@ -321,7 +334,7 @@ class Server:
|
||||
|
||||
async def checkout(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
) -> AsyncContextManager[Connection]:
|
||||
) -> AsyncContextManager[AsyncConnection]:
|
||||
return self.pool.checkout(handler)
|
||||
|
||||
@property
|
||||
|
||||
@ -1,301 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Represent one server the driver is connected to."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from bson import EPOCH_NAIVE
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.typings import ClusterTime, _Address
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class ServerDescription:
|
||||
"""Immutable representation of one server.
|
||||
|
||||
:param address: A (host, port) pair
|
||||
:param hello: Optional Hello instance
|
||||
:param round_trip_time: Optional float
|
||||
:param error: Optional, the last error attempting to connect to the server
|
||||
:param round_trip_time: Optional float, the min latency from the most recent samples
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_address",
|
||||
"_server_type",
|
||||
"_all_hosts",
|
||||
"_tags",
|
||||
"_replica_set_name",
|
||||
"_primary",
|
||||
"_max_bson_size",
|
||||
"_max_message_size",
|
||||
"_max_write_batch_size",
|
||||
"_min_wire_version",
|
||||
"_max_wire_version",
|
||||
"_round_trip_time",
|
||||
"_min_round_trip_time",
|
||||
"_me",
|
||||
"_is_writable",
|
||||
"_is_readable",
|
||||
"_ls_timeout_minutes",
|
||||
"_error",
|
||||
"_set_version",
|
||||
"_election_id",
|
||||
"_cluster_time",
|
||||
"_last_write_date",
|
||||
"_last_update_time",
|
||||
"_topology_version",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: _Address,
|
||||
hello: Optional[Hello] = None,
|
||||
round_trip_time: Optional[float] = None,
|
||||
error: Optional[Exception] = None,
|
||||
min_round_trip_time: float = 0.0,
|
||||
) -> None:
|
||||
self._address = address
|
||||
if not hello:
|
||||
hello = Hello({})
|
||||
|
||||
self._server_type = hello.server_type
|
||||
self._all_hosts = hello.all_hosts
|
||||
self._tags = hello.tags
|
||||
self._replica_set_name = hello.replica_set_name
|
||||
self._primary = hello.primary
|
||||
self._max_bson_size = hello.max_bson_size
|
||||
self._max_message_size = hello.max_message_size
|
||||
self._max_write_batch_size = hello.max_write_batch_size
|
||||
self._min_wire_version = hello.min_wire_version
|
||||
self._max_wire_version = hello.max_wire_version
|
||||
self._set_version = hello.set_version
|
||||
self._election_id = hello.election_id
|
||||
self._cluster_time = hello.cluster_time
|
||||
self._is_writable = hello.is_writable
|
||||
self._is_readable = hello.is_readable
|
||||
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
|
||||
self._round_trip_time = round_trip_time
|
||||
self._min_round_trip_time = min_round_trip_time
|
||||
self._me = hello.me
|
||||
self._last_update_time = time.monotonic()
|
||||
self._error = error
|
||||
self._topology_version = hello.topology_version
|
||||
if error:
|
||||
details = getattr(error, "details", None)
|
||||
if isinstance(details, dict):
|
||||
self._topology_version = details.get("topologyVersion")
|
||||
|
||||
self._last_write_date: Optional[float]
|
||||
if hello.last_write_date:
|
||||
# Convert from datetime to seconds.
|
||||
delta = hello.last_write_date - EPOCH_NAIVE
|
||||
self._last_write_date = delta.total_seconds()
|
||||
else:
|
||||
self._last_write_date = None
|
||||
|
||||
@property
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) of this server."""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def server_type(self) -> int:
|
||||
"""The type of this server."""
|
||||
return self._server_type
|
||||
|
||||
@property
|
||||
def server_type_name(self) -> str:
|
||||
"""The server type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return SERVER_TYPE._fields[self._server_type]
|
||||
|
||||
@property
|
||||
def all_hosts(self) -> set[tuple[str, int]]:
|
||||
"""List of hosts, passives, and arbiters known to this server."""
|
||||
return self._all_hosts
|
||||
|
||||
@property
|
||||
def tags(self) -> Mapping[str, Any]:
|
||||
return self._tags
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def primary(self) -> Optional[tuple[str, int]]:
|
||||
"""This server's opinion about who the primary is, or None."""
|
||||
return self._primary
|
||||
|
||||
@property
|
||||
def max_bson_size(self) -> int:
|
||||
return self._max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self) -> int:
|
||||
return self._max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self) -> int:
|
||||
return self._max_write_batch_size
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
return self._min_wire_version
|
||||
|
||||
@property
|
||||
def max_wire_version(self) -> int:
|
||||
return self._max_wire_version
|
||||
|
||||
@property
|
||||
def set_version(self) -> Optional[int]:
|
||||
return self._set_version
|
||||
|
||||
@property
|
||||
def election_id(self) -> Optional[ObjectId]:
|
||||
return self._election_id
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
|
||||
warnings.warn(
|
||||
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._set_version, self._election_id
|
||||
|
||||
@property
|
||||
def me(self) -> Optional[tuple[str, int]]:
|
||||
return self._me
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def last_write_date(self) -> Optional[float]:
|
||||
return self._last_write_date
|
||||
|
||||
@property
|
||||
def last_update_time(self) -> float:
|
||||
return self._last_update_time
|
||||
|
||||
@property
|
||||
def round_trip_time(self) -> Optional[float]:
|
||||
"""The current average latency or None."""
|
||||
# This override is for unittesting only!
|
||||
if self._address in self._host_to_round_trip_time:
|
||||
return self._host_to_round_trip_time[self._address]
|
||||
|
||||
return self._round_trip_time
|
||||
|
||||
@property
|
||||
def min_round_trip_time(self) -> float:
|
||||
"""The min latency from the most recent samples."""
|
||||
return self._min_round_trip_time
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[Exception]:
|
||||
"""The last error attempting to connect to the server, or None."""
|
||||
return self._error
|
||||
|
||||
@property
|
||||
def is_writable(self) -> bool:
|
||||
return self._is_writable
|
||||
|
||||
@property
|
||||
def is_readable(self) -> bool:
|
||||
return self._is_readable
|
||||
|
||||
@property
|
||||
def mongos(self) -> bool:
|
||||
return self._server_type == SERVER_TYPE.Mongos
|
||||
|
||||
@property
|
||||
def is_server_type_known(self) -> bool:
|
||||
return self.server_type != SERVER_TYPE.Unknown
|
||||
|
||||
@property
|
||||
def retryable_writes_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return (
|
||||
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
|
||||
) or self._server_type == SERVER_TYPE.LoadBalancer
|
||||
|
||||
@property
|
||||
def retryable_reads_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return self._max_wire_version >= 6
|
||||
|
||||
@property
|
||||
def topology_version(self) -> Optional[Mapping[str, Any]]:
|
||||
return self._topology_version
|
||||
|
||||
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
|
||||
unknown = ServerDescription(self.address, error=error)
|
||||
unknown._topology_version = self.topology_version
|
||||
return unknown
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, ServerDescription):
|
||||
return (
|
||||
(self._address == other.address)
|
||||
and (self._server_type == other.server_type)
|
||||
and (self._min_wire_version == other.min_wire_version)
|
||||
and (self._max_wire_version == other.max_wire_version)
|
||||
and (self._me == other.me)
|
||||
and (self._all_hosts == other.all_hosts)
|
||||
and (self._tags == other.tags)
|
||||
and (self._replica_set_name == other.replica_set_name)
|
||||
and (self._set_version == other.set_version)
|
||||
and (self._election_id == other.election_id)
|
||||
and (self._primary == other.primary)
|
||||
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
|
||||
and (self._error == other.error)
|
||||
)
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self) -> str:
|
||||
errmsg = ""
|
||||
if self.error:
|
||||
errmsg = f", error={self.error!r}"
|
||||
return "<{} {} server_type: {}, rtt: {}{}>".format(
|
||||
self.__class__.__name__,
|
||||
self.address,
|
||||
self.server_type_name,
|
||||
self.round_trip_time,
|
||||
errmsg,
|
||||
)
|
||||
|
||||
# For unittesting only. Use under no circumstances!
|
||||
_host_to_round_trip_time: dict = {}
|
||||
@ -1,175 +0,0 @@
|
||||
# Copyright 2014-2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Criteria to select some ServerDescriptions from a TopologyDescription."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast
|
||||
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.topology_description import TopologyDescription
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
T = TypeVar("T")
|
||||
TagSet = Mapping[str, Any]
|
||||
TagSets = Sequence[TagSet]
|
||||
|
||||
|
||||
class Selection:
|
||||
"""Input or output of a server selector function."""
|
||||
|
||||
@classmethod
|
||||
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
|
||||
known_servers = topology_description.known_servers
|
||||
primary = None
|
||||
for sd in known_servers:
|
||||
if sd.server_type == SERVER_TYPE.RSPrimary:
|
||||
primary = sd
|
||||
break
|
||||
|
||||
return Selection(
|
||||
topology_description,
|
||||
topology_description.known_servers,
|
||||
topology_description.common_wire_version,
|
||||
primary,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
topology_description: TopologyDescription,
|
||||
server_descriptions: list[ServerDescription],
|
||||
common_wire_version: Optional[int],
|
||||
primary: Optional[ServerDescription],
|
||||
):
|
||||
self.topology_description = topology_description
|
||||
self.server_descriptions = server_descriptions
|
||||
self.primary = primary
|
||||
self.common_wire_version = common_wire_version
|
||||
|
||||
def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection:
|
||||
return Selection(
|
||||
self.topology_description, server_descriptions, self.common_wire_version, self.primary
|
||||
)
|
||||
|
||||
def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]:
|
||||
secondaries = secondary_server_selector(self)
|
||||
if secondaries.server_descriptions:
|
||||
return max(
|
||||
secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date)
|
||||
)
|
||||
return None
|
||||
|
||||
@property
|
||||
def primary_selection(self) -> Selection:
|
||||
primaries = [self.primary] if self.primary else []
|
||||
return self.with_server_descriptions(primaries)
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
return self.topology_description.heartbeat_frequency
|
||||
|
||||
@property
|
||||
def topology_type(self) -> int:
|
||||
return self.topology_description.topology_type
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.server_descriptions)
|
||||
|
||||
def __getitem__(self, item: int) -> ServerDescription:
|
||||
return self.server_descriptions[item]
|
||||
|
||||
|
||||
def any_server_selector(selection: T) -> T:
|
||||
return selection
|
||||
|
||||
|
||||
def readable_server_selector(selection: Selection) -> Selection:
|
||||
return selection.with_server_descriptions(
|
||||
[s for s in selection.server_descriptions if s.is_readable]
|
||||
)
|
||||
|
||||
|
||||
def writable_server_selector(selection: Selection) -> Selection:
|
||||
return selection.with_server_descriptions(
|
||||
[s for s in selection.server_descriptions if s.is_writable]
|
||||
)
|
||||
|
||||
|
||||
def secondary_server_selector(selection: Selection) -> Selection:
|
||||
return selection.with_server_descriptions(
|
||||
[s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary]
|
||||
)
|
||||
|
||||
|
||||
def arbiter_server_selector(selection: Selection) -> Selection:
|
||||
return selection.with_server_descriptions(
|
||||
[s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter]
|
||||
)
|
||||
|
||||
|
||||
def writable_preferred_server_selector(selection: Selection) -> Selection:
|
||||
"""Like PrimaryPreferred but doesn't use tags or latency."""
|
||||
return writable_server_selector(selection) or secondary_server_selector(selection)
|
||||
|
||||
|
||||
def apply_single_tag_set(tag_set: TagSet, selection: Selection) -> Selection:
|
||||
"""All servers matching one tag set.
|
||||
|
||||
A tag set is a dict. A server matches if its tags are a superset:
|
||||
A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}.
|
||||
|
||||
The empty tag set {} matches any server.
|
||||
"""
|
||||
|
||||
def tags_match(server_tags: Mapping[str, Any]) -> bool:
|
||||
for key, value in tag_set.items():
|
||||
if key not in server_tags or server_tags[key] != value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return selection.with_server_descriptions(
|
||||
[s for s in selection.server_descriptions if tags_match(s.tags)]
|
||||
)
|
||||
|
||||
|
||||
def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection:
|
||||
"""All servers match a list of tag sets.
|
||||
|
||||
tag_sets is a list of dicts. The empty tag set {} matches any server,
|
||||
and may be provided at the end of the list as a fallback. So
|
||||
[{'a': 'value'}, {}] expresses a preference for servers tagged
|
||||
{'a': 'value'}, but accepts any server if none matches the first
|
||||
preference.
|
||||
"""
|
||||
for tag_set in tag_sets:
|
||||
with_tag_set = apply_single_tag_set(tag_set, selection)
|
||||
if with_tag_set:
|
||||
return with_tag_set
|
||||
|
||||
return selection.with_server_descriptions([])
|
||||
|
||||
|
||||
def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
|
||||
"""All near-enough secondaries matching the tag sets."""
|
||||
return apply_tag_sets(tag_sets, secondary_server_selector(selection))
|
||||
|
||||
|
||||
def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
|
||||
"""All near-enough members matching the tag sets."""
|
||||
return apply_tag_sets(tag_sets, readable_server_selector(selection))
|
||||
@ -20,12 +20,14 @@ import traceback
|
||||
from typing import Any, Collection, Optional, Type, Union
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous import common, monitor, pool
|
||||
from pymongo.asynchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
|
||||
from pymongo.asynchronous.pool import Pool, PoolOptions
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector
|
||||
from pymongo import common
|
||||
from pymongo.asynchronous import monitor, pool
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
@ -1,149 +0,0 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo.asynchronous.common import CONNECT_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dns import resolver
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _have_dnspython() -> bool:
|
||||
try:
|
||||
import dns # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# dnspython can return bytes or str from various parts
|
||||
# of its API depending on version. We always want str.
|
||||
def maybe_decode(text: Union[str, bytes]) -> str:
|
||||
if isinstance(text, bytes):
|
||||
return text.decode()
|
||||
return text
|
||||
|
||||
|
||||
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
|
||||
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
|
||||
from dns import resolver
|
||||
|
||||
if hasattr(resolver, "resolve"):
|
||||
# dnspython >= 2
|
||||
return resolver.resolve(*args, **kwargs)
|
||||
# dnspython 1.X
|
||||
return resolver.query(*args, **kwargs)
|
||||
|
||||
|
||||
_INVALID_HOST_MSG = (
|
||||
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
|
||||
"Did you mean to use 'mongodb://'?"
|
||||
)
|
||||
|
||||
|
||||
class _SrvResolver:
|
||||
def __init__(
|
||||
self,
|
||||
fqdn: str,
|
||||
connect_timeout: Optional[float],
|
||||
srv_service_name: str,
|
||||
srv_max_hosts: int = 0,
|
||||
):
|
||||
self.__fqdn = fqdn
|
||||
self.__srv = srv_service_name
|
||||
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
|
||||
self.__srv_max_hosts = srv_max_hosts or 0
|
||||
# Validate the fully qualified domain name.
|
||||
try:
|
||||
ipaddress.ip_address(fqdn)
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.__plist = self.__fqdn.split(".")[1:]
|
||||
except Exception:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||
self.__slen = len(self.__plist)
|
||||
if self.__slen < 2:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
|
||||
|
||||
def get_options(self) -> Optional[str]:
|
||||
from dns import resolver
|
||||
|
||||
try:
|
||||
results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
|
||||
except (resolver.NoAnswer, resolver.NXDOMAIN):
|
||||
# No TXT records
|
||||
return None
|
||||
except Exception as exc:
|
||||
raise ConfigurationError(str(exc)) from None
|
||||
if len(results) > 1:
|
||||
raise ConfigurationError("Only one TXT record is supported")
|
||||
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")
|
||||
|
||||
def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
|
||||
try:
|
||||
results = _resolve(
|
||||
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
|
||||
)
|
||||
except Exception as exc:
|
||||
if not encapsulate_errors:
|
||||
# Raise the original error.
|
||||
raise
|
||||
# Else, raise all errors as ConfigurationError.
|
||||
raise ConfigurationError(str(exc)) from None
|
||||
return results
|
||||
|
||||
def _get_srv_response_and_hosts(
|
||||
self, encapsulate_errors: bool
|
||||
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
|
||||
results = self._resolve_uri(encapsulate_errors)
|
||||
|
||||
# Construct address tuples
|
||||
nodes = [
|
||||
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results
|
||||
]
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
try:
|
||||
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||
except Exception:
|
||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
|
||||
if self.__plist != nlist:
|
||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
|
||||
if self.__srv_max_hosts:
|
||||
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
|
||||
return results, nodes
|
||||
|
||||
def get_hosts(self) -> list[tuple[str, Any]]:
|
||||
_, nodes = self._get_srv_response_and_hosts(True)
|
||||
return nodes
|
||||
|
||||
def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
|
||||
results, nodes = self._get_srv_response_and_hosts(False)
|
||||
rrset = results.rrset
|
||||
ttl = rrset.ttl if rrset else 0
|
||||
return nodes, ttl
|
||||
@ -27,33 +27,12 @@ import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, helpers_constants
|
||||
from pymongo.asynchronous import common, periodic_executor
|
||||
from pymongo import _csot, common, helpers_shared
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.logger import (
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
_debug_log,
|
||||
_ServerSelectionStatusMessage,
|
||||
)
|
||||
from pymongo.asynchronous.monitor import SrvMonitor
|
||||
from pymongo.asynchronous.pool import Pool, PoolOptions
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.server_selectors import (
|
||||
Selection,
|
||||
any_server_selector,
|
||||
arbiter_server_selector,
|
||||
secondary_server_selector,
|
||||
writable_server_selector,
|
||||
)
|
||||
from pymongo.asynchronous.topology_description import (
|
||||
SRV_POLLING_TOPOLOGIES,
|
||||
TOPOLOGY_TYPE,
|
||||
TopologyDescription,
|
||||
_updated_topology_description_srv_polling,
|
||||
updated_topology_description,
|
||||
)
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
@ -64,12 +43,34 @@ from pymongo.errors import (
|
||||
ServerSelectionTimeoutError,
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
||||
from pymongo.logger import (
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
_debug_log,
|
||||
_ServerSelectionStatusMessage,
|
||||
)
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import (
|
||||
Selection,
|
||||
any_server_selector,
|
||||
arbiter_server_selector,
|
||||
secondary_server_selector,
|
||||
writable_server_selector,
|
||||
)
|
||||
from pymongo.topology_description import (
|
||||
SRV_POLLING_TOPOLOGIES,
|
||||
TOPOLOGY_TYPE,
|
||||
TopologyDescription,
|
||||
_updated_topology_description_srv_polling,
|
||||
updated_topology_description,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import ObjectId
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.typings import ClusterTime, _Address
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -781,8 +782,8 @@ class Topology:
|
||||
# Default error code if one does not exist.
|
||||
default = 10107 if isinstance(error, NotPrimaryError) else None
|
||||
err_code = error.details.get("code", default) # type: ignore[union-attr]
|
||||
if err_code in helpers_constants._NOT_PRIMARY_CODES:
|
||||
is_shutting_down = err_code in helpers_constants._SHUTDOWN_CODES
|
||||
if err_code in helpers_shared._NOT_PRIMARY_CODES:
|
||||
is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES
|
||||
# Mark server Unknown, clear the pool, and request check.
|
||||
if not self._settings.load_balanced:
|
||||
await self._process_change(ServerDescription(address, error=error))
|
||||
|
||||
@ -1,678 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Represent a deployment of MongoDB servers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from random import sample
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
|
||||
from bson.min_key import MinKey
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
from pymongo.asynchronous.typings import _Address
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
# Enumeration for various kinds of MongoDB cluster topologies.
|
||||
class _TopologyType(NamedTuple):
|
||||
Single: int
|
||||
ReplicaSetNoPrimary: int
|
||||
ReplicaSetWithPrimary: int
|
||||
Sharded: int
|
||||
Unknown: int
|
||||
LoadBalanced: int
|
||||
|
||||
|
||||
TOPOLOGY_TYPE = _TopologyType(*range(6))
|
||||
|
||||
# Topologies compatible with SRV record polling.
|
||||
SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
|
||||
|
||||
|
||||
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
|
||||
|
||||
|
||||
class TopologyDescription:
|
||||
def __init__(
|
||||
self,
|
||||
topology_type: int,
|
||||
server_descriptions: dict[_Address, ServerDescription],
|
||||
replica_set_name: Optional[str],
|
||||
max_set_version: Optional[int],
|
||||
max_election_id: Optional[ObjectId],
|
||||
topology_settings: Any,
|
||||
) -> None:
|
||||
"""Representation of a deployment of MongoDB servers.
|
||||
|
||||
:param topology_type: initial type
|
||||
:param server_descriptions: dict of (address, ServerDescription) for
|
||||
all seeds
|
||||
:param replica_set_name: replica set name or None
|
||||
:param max_set_version: greatest setVersion seen from a primary, or None
|
||||
:param max_election_id: greatest electionId seen from a primary, or None
|
||||
:param topology_settings: a TopologySettings
|
||||
"""
|
||||
self._topology_type = topology_type
|
||||
self._replica_set_name = replica_set_name
|
||||
self._server_descriptions = server_descriptions
|
||||
self._max_set_version = max_set_version
|
||||
self._max_election_id = max_election_id
|
||||
|
||||
# The heartbeat_frequency is used in staleness estimates.
|
||||
self._topology_settings = topology_settings
|
||||
|
||||
# Is PyMongo compatible with all servers' wire protocols?
|
||||
self._incompatible_err = None
|
||||
if self._topology_type != TOPOLOGY_TYPE.LoadBalanced:
|
||||
self._init_incompatible_err()
|
||||
|
||||
# Server Discovery And Monitoring Spec: Whenever a client updates the
|
||||
# TopologyDescription from an hello response, it MUST set
|
||||
# TopologyDescription.logicalSessionTimeoutMinutes to the smallest
|
||||
# logicalSessionTimeoutMinutes value among ServerDescriptions of all
|
||||
# data-bearing server types. If any have a null
|
||||
# logicalSessionTimeoutMinutes, then
|
||||
# TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null.
|
||||
readable_servers = self.readable_servers
|
||||
if not readable_servers:
|
||||
self._ls_timeout_minutes = None
|
||||
elif any(s.logical_session_timeout_minutes is None for s in readable_servers):
|
||||
self._ls_timeout_minutes = None
|
||||
else:
|
||||
self._ls_timeout_minutes = min( # type: ignore[type-var]
|
||||
s.logical_session_timeout_minutes for s in readable_servers
|
||||
)
|
||||
|
||||
def _init_incompatible_err(self) -> None:
|
||||
"""Internal compatibility check for non-load balanced topologies."""
|
||||
for s in self._server_descriptions.values():
|
||||
if not s.is_server_type_known:
|
||||
continue
|
||||
|
||||
# s.min/max_wire_version is the server's wire protocol.
|
||||
# MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports.
|
||||
server_too_new = (
|
||||
# Server too new.
|
||||
s.min_wire_version is not None
|
||||
and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION
|
||||
)
|
||||
|
||||
server_too_old = (
|
||||
# Server too old.
|
||||
s.max_wire_version is not None
|
||||
and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION
|
||||
)
|
||||
|
||||
if server_too_new:
|
||||
self._incompatible_err = (
|
||||
"Server at %s:%d requires wire version %d, but this " # type: ignore
|
||||
"version of PyMongo only supports up to %d."
|
||||
% (
|
||||
s.address[0],
|
||||
s.address[1] or 0,
|
||||
s.min_wire_version,
|
||||
common.MAX_SUPPORTED_WIRE_VERSION,
|
||||
)
|
||||
)
|
||||
|
||||
elif server_too_old:
|
||||
self._incompatible_err = (
|
||||
"Server at %s:%d reports wire version %d, but this " # type: ignore
|
||||
"version of PyMongo requires at least %d (MongoDB %s)."
|
||||
% (
|
||||
s.address[0],
|
||||
s.address[1] or 0,
|
||||
s.max_wire_version,
|
||||
common.MIN_SUPPORTED_WIRE_VERSION,
|
||||
common.MIN_SUPPORTED_SERVER_VERSION,
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
def check_compatible(self) -> None:
|
||||
"""Raise ConfigurationError if any server is incompatible.
|
||||
|
||||
A server is incompatible if its wire protocol version range does not
|
||||
overlap with PyMongo's.
|
||||
"""
|
||||
if self._incompatible_err:
|
||||
raise ConfigurationError(self._incompatible_err)
|
||||
|
||||
def has_server(self, address: _Address) -> bool:
|
||||
return address in self._server_descriptions
|
||||
|
||||
def reset_server(self, address: _Address) -> TopologyDescription:
|
||||
"""A copy of this description, with one server marked Unknown."""
|
||||
unknown_sd = self._server_descriptions[address].to_unknown()
|
||||
return updated_topology_description(self, unknown_sd)
|
||||
|
||||
def reset(self) -> TopologyDescription:
|
||||
"""A copy of this description, with all servers marked Unknown."""
|
||||
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
|
||||
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
|
||||
else:
|
||||
topology_type = self._topology_type
|
||||
|
||||
# The default ServerDescription's type is Unknown.
|
||||
sds = {address: ServerDescription(address) for address in self._server_descriptions}
|
||||
|
||||
return TopologyDescription(
|
||||
topology_type,
|
||||
sds,
|
||||
self._replica_set_name,
|
||||
self._max_set_version,
|
||||
self._max_election_id,
|
||||
self._topology_settings,
|
||||
)
|
||||
|
||||
def server_descriptions(self) -> dict[_Address, ServerDescription]:
|
||||
"""dict of (address,
|
||||
:class:`~pymongo.server_description.ServerDescription`).
|
||||
"""
|
||||
return self._server_descriptions.copy()
|
||||
|
||||
@property
|
||||
def topology_type(self) -> int:
|
||||
"""The type of this topology."""
|
||||
return self._topology_type
|
||||
|
||||
@property
|
||||
def topology_type_name(self) -> str:
|
||||
"""The topology type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return TOPOLOGY_TYPE._fields[self._topology_type]
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""The replica set name."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def max_set_version(self) -> Optional[int]:
|
||||
"""Greatest setVersion seen from a primary, or None."""
|
||||
return self._max_set_version
|
||||
|
||||
@property
|
||||
def max_election_id(self) -> Optional[ObjectId]:
|
||||
"""Greatest electionId seen from a primary, or None."""
|
||||
return self._max_election_id
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
"""Minimum logical session timeout, or None."""
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def known_servers(self) -> list[ServerDescription]:
|
||||
"""List of Servers of types besides Unknown."""
|
||||
return [s for s in self._server_descriptions.values() if s.is_server_type_known]
|
||||
|
||||
@property
|
||||
def has_known_servers(self) -> bool:
|
||||
"""Whether there are any Servers of types besides Unknown."""
|
||||
return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
|
||||
|
||||
@property
|
||||
def readable_servers(self) -> list[ServerDescription]:
|
||||
"""List of readable Servers."""
|
||||
return [s for s in self._server_descriptions.values() if s.is_readable]
|
||||
|
||||
@property
|
||||
def common_wire_version(self) -> Optional[int]:
|
||||
"""Minimum of all servers' max wire versions, or None."""
|
||||
servers = self.known_servers
|
||||
if servers:
|
||||
return min(s.max_wire_version for s in self.known_servers)
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
return self._topology_settings.heartbeat_frequency
|
||||
|
||||
@property
|
||||
def srv_max_hosts(self) -> int:
|
||||
return self._topology_settings._srv_max_hosts
|
||||
|
||||
def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
|
||||
if not selection:
|
||||
return []
|
||||
round_trip_times: list[float] = []
|
||||
for server in selection.server_descriptions:
|
||||
if server.round_trip_time is None:
|
||||
config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}"
|
||||
raise ConfigurationError(config_err_msg)
|
||||
round_trip_times.append(server.round_trip_time)
|
||||
# Round trip time in seconds.
|
||||
fastest = min(round_trip_times)
|
||||
threshold = self._topology_settings.local_threshold_ms / 1000.0
|
||||
return [
|
||||
s
|
||||
for s in selection.server_descriptions
|
||||
if (cast(float, s.round_trip_time) - fastest) <= threshold
|
||||
]
|
||||
|
||||
def apply_selector(
|
||||
self,
|
||||
selector: Any,
|
||||
address: Optional[_Address] = None,
|
||||
custom_selector: Optional[_ServerSelector] = None,
|
||||
) -> list[ServerDescription]:
|
||||
"""List of servers matching the provided selector(s).
|
||||
|
||||
:param selector: a callable that takes a Selection as input and returns
|
||||
a Selection as output. For example, an instance of a read
|
||||
preference from :mod:`~pymongo.read_preferences`.
|
||||
:param address: A server address to select.
|
||||
:param custom_selector: A callable that augments server
|
||||
selection rules. Accepts a list of
|
||||
:class:`~pymongo.server_description.ServerDescription` objects and
|
||||
return a list of server descriptions that should be considered
|
||||
suitable for the desired operation.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
if getattr(selector, "min_wire_version", 0):
|
||||
common_wv = self.common_wire_version
|
||||
if common_wv and common_wv < selector.min_wire_version:
|
||||
raise ConfigurationError(
|
||||
"%s requires min wire version %d, but topology's min"
|
||||
" wire version is %d" % (selector, selector.min_wire_version, common_wv)
|
||||
)
|
||||
|
||||
if isinstance(selector, _AggWritePref):
|
||||
selector.selection_hook(self)
|
||||
|
||||
if self.topology_type == TOPOLOGY_TYPE.Unknown:
|
||||
return []
|
||||
elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced):
|
||||
# Ignore selectors for standalone and load balancer mode.
|
||||
return self.known_servers
|
||||
if address:
|
||||
# Ignore selectors when explicit address is requested.
|
||||
description = self.server_descriptions().get(address)
|
||||
return [description] if description else []
|
||||
|
||||
selection = Selection.from_topology_description(self)
|
||||
# Ignore read preference for sharded clusters.
|
||||
if self.topology_type != TOPOLOGY_TYPE.Sharded:
|
||||
selection = selector(selection)
|
||||
|
||||
# Apply custom selector followed by localThresholdMS.
|
||||
if custom_selector is not None and selection:
|
||||
selection = selection.with_server_descriptions(
|
||||
custom_selector(selection.server_descriptions)
|
||||
)
|
||||
return self._apply_local_threshold(selection)
|
||||
|
||||
def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool:
|
||||
"""Does this topology have any readable servers available matching the
|
||||
given read preference?
|
||||
|
||||
:param read_preference: an instance of a read preference from
|
||||
:mod:`~pymongo.read_preferences`. Defaults to
|
||||
:attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`.
|
||||
|
||||
.. note:: When connected directly to a single server this method
|
||||
always returns ``True``.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
common.validate_read_preference("read_preference", read_preference)
|
||||
return any(self.apply_selector(read_preference))
|
||||
|
||||
def has_writable_server(self) -> bool:
|
||||
"""Does this topology have a writable server available?
|
||||
|
||||
.. note:: When connected directly to a single server this method
|
||||
always returns ``True``.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return self.has_readable_server(ReadPreference.PRIMARY)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Sort the servers by address.
|
||||
servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address)
|
||||
return "<{} id: {}, topology_type: {}, servers: {!r}>".format(
|
||||
self.__class__.__name__,
|
||||
self._topology_settings._topology_id,
|
||||
self.topology_type_name,
|
||||
servers,
|
||||
)
|
||||
|
||||
|
||||
# If topology type is Unknown and we receive a hello response, what should
|
||||
# the new topology type be?
|
||||
_SERVER_TYPE_TO_TOPOLOGY_TYPE = {
|
||||
SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded,
|
||||
SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary,
|
||||
SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
|
||||
SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
|
||||
SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
|
||||
# Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out.
|
||||
}
|
||||
|
||||
|
||||
def updated_topology_description(
|
||||
topology_description: TopologyDescription, server_description: ServerDescription
|
||||
) -> TopologyDescription:
|
||||
"""Return an updated copy of a TopologyDescription.
|
||||
|
||||
:param topology_description: the current TopologyDescription
|
||||
:param server_description: a new ServerDescription that resulted from
|
||||
a hello call
|
||||
|
||||
Called after attempting (successfully or not) to call hello on the
|
||||
server at server_description.address. Does not modify topology_description.
|
||||
"""
|
||||
address = server_description.address
|
||||
|
||||
# These values will be updated, if necessary, to form the new
|
||||
# TopologyDescription.
|
||||
topology_type = topology_description.topology_type
|
||||
set_name = topology_description.replica_set_name
|
||||
max_set_version = topology_description.max_set_version
|
||||
max_election_id = topology_description.max_election_id
|
||||
server_type = server_description.server_type
|
||||
|
||||
# Don't mutate the original dict of server descriptions; copy it.
|
||||
sds = topology_description.server_descriptions()
|
||||
|
||||
# Replace this server's description with the new one.
|
||||
sds[address] = server_description
|
||||
|
||||
if topology_type == TOPOLOGY_TYPE.Single:
|
||||
# Set server type to Unknown if replica set name does not match.
|
||||
if set_name is not None and set_name != server_description.replica_set_name:
|
||||
error = ConfigurationError(
|
||||
"client is configured to connect to a replica set named "
|
||||
"'{}' but this node belongs to a set named '{}'".format(
|
||||
set_name, server_description.replica_set_name
|
||||
)
|
||||
)
|
||||
sds[address] = server_description.to_unknown(error=error)
|
||||
# Single type never changes.
|
||||
return TopologyDescription(
|
||||
TOPOLOGY_TYPE.Single,
|
||||
sds,
|
||||
set_name,
|
||||
max_set_version,
|
||||
max_election_id,
|
||||
topology_description._topology_settings,
|
||||
)
|
||||
|
||||
if topology_type == TOPOLOGY_TYPE.Unknown:
|
||||
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer):
|
||||
if len(topology_description._topology_settings.seeds) == 1:
|
||||
topology_type = TOPOLOGY_TYPE.Single
|
||||
else:
|
||||
# Remove standalone from Topology when given multiple seeds.
|
||||
sds.pop(address)
|
||||
elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost):
|
||||
topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type]
|
||||
|
||||
if topology_type == TOPOLOGY_TYPE.Sharded:
|
||||
if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown):
|
||||
sds.pop(address)
|
||||
|
||||
elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary:
|
||||
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
|
||||
sds.pop(address)
|
||||
|
||||
elif server_type == SERVER_TYPE.RSPrimary:
|
||||
(topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
|
||||
sds, set_name, server_description, max_set_version, max_election_id
|
||||
)
|
||||
|
||||
elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
|
||||
topology_type, set_name = _update_rs_no_primary_from_member(
|
||||
sds, set_name, server_description
|
||||
)
|
||||
|
||||
elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
|
||||
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
|
||||
sds.pop(address)
|
||||
topology_type = _check_has_primary(sds)
|
||||
|
||||
elif server_type == SERVER_TYPE.RSPrimary:
|
||||
(topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
|
||||
sds, set_name, server_description, max_set_version, max_election_id
|
||||
)
|
||||
|
||||
elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
|
||||
topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description)
|
||||
|
||||
else:
|
||||
# Server type is Unknown or RSGhost: did we just lose the primary?
|
||||
topology_type = _check_has_primary(sds)
|
||||
|
||||
# Return updated copy.
|
||||
return TopologyDescription(
|
||||
topology_type,
|
||||
sds,
|
||||
set_name,
|
||||
max_set_version,
|
||||
max_election_id,
|
||||
topology_description._topology_settings,
|
||||
)
|
||||
|
||||
|
||||
def _updated_topology_description_srv_polling(
|
||||
topology_description: TopologyDescription, seedlist: list[tuple[str, Any]]
|
||||
) -> TopologyDescription:
|
||||
"""Return an updated copy of a TopologyDescription.
|
||||
|
||||
:param topology_description: the current TopologyDescription
|
||||
:param seedlist: a list of new seeds new ServerDescription that resulted from
|
||||
a hello call
|
||||
"""
|
||||
assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
|
||||
# Create a copy of the server descriptions.
|
||||
sds = topology_description.server_descriptions()
|
||||
|
||||
# If seeds haven't changed, don't do anything.
|
||||
if set(sds.keys()) == set(seedlist):
|
||||
return topology_description
|
||||
|
||||
# Remove SDs corresponding to servers no longer part of the SRV record.
|
||||
for address in list(sds.keys()):
|
||||
if address not in seedlist:
|
||||
sds.pop(address)
|
||||
|
||||
if topology_description.srv_max_hosts != 0:
|
||||
new_hosts = set(seedlist) - set(sds.keys())
|
||||
n_to_add = topology_description.srv_max_hosts - len(sds)
|
||||
if n_to_add > 0:
|
||||
seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
|
||||
else:
|
||||
seedlist = []
|
||||
# Add SDs corresponding to servers recently added to the SRV record.
|
||||
for address in seedlist:
|
||||
if address not in sds:
|
||||
sds[address] = ServerDescription(address)
|
||||
return TopologyDescription(
|
||||
topology_description.topology_type,
|
||||
sds,
|
||||
topology_description.replica_set_name,
|
||||
topology_description.max_set_version,
|
||||
topology_description.max_election_id,
|
||||
topology_description._topology_settings,
|
||||
)
|
||||
|
||||
|
||||
def _update_rs_from_primary(
|
||||
sds: MutableMapping[_Address, ServerDescription],
|
||||
replica_set_name: Optional[str],
|
||||
server_description: ServerDescription,
|
||||
max_set_version: Optional[int],
|
||||
max_election_id: Optional[ObjectId],
|
||||
) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
|
||||
"""Update topology description from a primary's hello response.
|
||||
|
||||
Pass in a dict of ServerDescriptions, current replica set name, the
|
||||
ServerDescription we are processing, and the TopologyDescription's
|
||||
max_set_version and max_election_id if any.
|
||||
|
||||
Returns (new topology type, new replica_set_name, new max_set_version,
|
||||
new max_election_id).
|
||||
"""
|
||||
if replica_set_name is None:
|
||||
replica_set_name = server_description.replica_set_name
|
||||
|
||||
elif replica_set_name != server_description.replica_set_name:
|
||||
# We found a primary but it doesn't have the replica_set_name
|
||||
# provided by the user.
|
||||
sds.pop(server_description.address)
|
||||
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
|
||||
|
||||
if server_description.max_wire_version is None or server_description.max_wire_version < 17:
|
||||
new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
|
||||
max_election_tuple: tuple = (max_set_version, max_election_id)
|
||||
if None not in new_election_tuple:
|
||||
if None not in max_election_tuple and new_election_tuple < max_election_tuple:
|
||||
# Stale primary, set to type Unknown.
|
||||
sds[server_description.address] = server_description.to_unknown()
|
||||
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
|
||||
max_election_id = server_description.election_id
|
||||
|
||||
if server_description.set_version is not None and (
|
||||
max_set_version is None or server_description.set_version > max_set_version
|
||||
):
|
||||
max_set_version = server_description.set_version
|
||||
else:
|
||||
new_election_tuple = server_description.election_id, server_description.set_version
|
||||
max_election_tuple = max_election_id, max_set_version
|
||||
new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple)
|
||||
max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple)
|
||||
if new_election_safe < max_election_safe:
|
||||
# Stale primary, set to type Unknown.
|
||||
sds[server_description.address] = server_description.to_unknown()
|
||||
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
|
||||
else:
|
||||
max_election_id = server_description.election_id
|
||||
max_set_version = server_description.set_version
|
||||
|
||||
# We've heard from the primary. Is it the same primary as before?
|
||||
for server in sds.values():
|
||||
if (
|
||||
server.server_type is SERVER_TYPE.RSPrimary
|
||||
and server.address != server_description.address
|
||||
):
|
||||
# Reset old primary's type to Unknown.
|
||||
sds[server.address] = server.to_unknown()
|
||||
|
||||
# There can be only one prior primary.
|
||||
break
|
||||
|
||||
# Discover new hosts from this primary's response.
|
||||
for new_address in server_description.all_hosts:
|
||||
if new_address not in sds:
|
||||
sds[new_address] = ServerDescription(new_address)
|
||||
|
||||
# Remove hosts not in the response.
|
||||
for addr in set(sds) - server_description.all_hosts:
|
||||
sds.pop(addr)
|
||||
|
||||
# If the host list differs from the seed list, we may not have a primary
|
||||
# after all.
|
||||
return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
|
||||
|
||||
|
||||
def _update_rs_with_primary_from_member(
|
||||
sds: MutableMapping[_Address, ServerDescription],
|
||||
replica_set_name: Optional[str],
|
||||
server_description: ServerDescription,
|
||||
) -> int:
|
||||
"""RS with known primary. Process a response from a non-primary.
|
||||
|
||||
Pass in a dict of ServerDescriptions, current replica set name, and the
|
||||
ServerDescription we are processing.
|
||||
|
||||
Returns new topology type.
|
||||
"""
|
||||
assert replica_set_name is not None
|
||||
|
||||
if replica_set_name != server_description.replica_set_name:
|
||||
sds.pop(server_description.address)
|
||||
elif server_description.me and server_description.address != server_description.me:
|
||||
sds.pop(server_description.address)
|
||||
|
||||
# Had this member been the primary?
|
||||
return _check_has_primary(sds)
|
||||
|
||||
|
||||
def _update_rs_no_primary_from_member(
|
||||
sds: MutableMapping[_Address, ServerDescription],
|
||||
replica_set_name: Optional[str],
|
||||
server_description: ServerDescription,
|
||||
) -> tuple[int, Optional[str]]:
|
||||
"""RS without known primary. Update from a non-primary's response.
|
||||
|
||||
Pass in a dict of ServerDescriptions, current replica set name, and the
|
||||
ServerDescription we are processing.
|
||||
|
||||
Returns (new topology type, new replica_set_name).
|
||||
"""
|
||||
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
|
||||
if replica_set_name is None:
|
||||
replica_set_name = server_description.replica_set_name
|
||||
|
||||
elif replica_set_name != server_description.replica_set_name:
|
||||
sds.pop(server_description.address)
|
||||
return topology_type, replica_set_name
|
||||
|
||||
# This isn't the primary's response, so don't remove any servers
|
||||
# it doesn't report. Only add new servers.
|
||||
for address in server_description.all_hosts:
|
||||
if address not in sds:
|
||||
sds[address] = ServerDescription(address)
|
||||
|
||||
if server_description.me and server_description.address != server_description.me:
|
||||
sds.pop(server_description.address)
|
||||
|
||||
return topology_type, replica_set_name
|
||||
|
||||
|
||||
def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int:
|
||||
"""Current topology type is ReplicaSetWithPrimary. Is primary still known?
|
||||
|
||||
Pass in a dict of ServerDescriptions.
|
||||
|
||||
Returns new topology type.
|
||||
"""
|
||||
for s in sds.values():
|
||||
if s.server_type == SERVER_TYPE.RSPrimary:
|
||||
return TOPOLOGY_TYPE.ReplicaSetWithPrimary
|
||||
else: # noqa: PLW0120
|
||||
return TOPOLOGY_TYPE.ReplicaSetNoPrimary
|
||||
@ -1,61 +0,0 @@
|
||||
# Copyright 2022-Present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Type aliases used by PyMongo"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.collation import Collation
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Common Shared Types.
|
||||
_Address = Tuple[str, Optional[int]]
|
||||
_CollationIn = Union[Mapping[str, Any], "Collation"]
|
||||
_Pipeline = Sequence[Mapping[str, Any]]
|
||||
ClusterTime = Mapping[str, Any]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def strip_optional(elem: Optional[_T]) -> _T:
|
||||
"""This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T
|
||||
while inside a list comprehension.
|
||||
"""
|
||||
assert elem is not None
|
||||
return elem
|
||||
|
||||
|
||||
__all__ = [
|
||||
"_DocumentOut",
|
||||
"_DocumentType",
|
||||
"_DocumentTypeArg",
|
||||
"_Address",
|
||||
"_CollationIn",
|
||||
"_Pipeline",
|
||||
"strip_optional",
|
||||
]
|
||||
@ -1,624 +0,0 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sized,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.asynchronous.client_options import _parse_ssl_options
|
||||
from pymongo.asynchronous.common import (
|
||||
INTERNAL_URI_OPTION_NAME_MAP,
|
||||
SRV_SERVICE_NAME,
|
||||
URI_OPTIONS_DEPRECATION_MAP,
|
||||
_CaseInsensitiveDictionary,
|
||||
get_validated_options,
|
||||
)
|
||||
from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver
|
||||
from pymongo.asynchronous.typings import _Address
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
|
||||
_IS_SYNC = False
|
||||
SCHEME = "mongodb://"
|
||||
SCHEME_LEN = len(SCHEME)
|
||||
SRV_SCHEME = "mongodb+srv://"
|
||||
SRV_SCHEME_LEN = len(SRV_SCHEME)
|
||||
DEFAULT_PORT = 27017
|
||||
|
||||
|
||||
def _unquoted_percent(s: str) -> bool:
|
||||
"""Check for unescaped percent signs.
|
||||
|
||||
:param s: A string. `s` can have things like '%25', '%2525',
|
||||
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
|
||||
"""
|
||||
for i in range(len(s)):
|
||||
if s[i] == "%":
|
||||
sub = s[i : i + 3]
|
||||
# If unquoting yields the same string this means there was an
|
||||
# unquoted %.
|
||||
if unquote_plus(sub) == sub:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parse_userinfo(userinfo: str) -> tuple[str, str]:
|
||||
"""Validates the format of user information in a MongoDB URI.
|
||||
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
|
||||
"]", "@") as per RFC 3986 must be escaped.
|
||||
|
||||
Returns a 2-tuple containing the unescaped username followed
|
||||
by the unescaped password.
|
||||
|
||||
:param userinfo: A string of the form <username>:<password>
|
||||
"""
|
||||
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
|
||||
raise InvalidURI(
|
||||
"Username and password must be escaped according to "
|
||||
"RFC 3986, use urllib.parse.quote_plus"
|
||||
)
|
||||
|
||||
user, _, passwd = userinfo.partition(":")
|
||||
# No password is expected with GSSAPI authentication.
|
||||
if not user:
|
||||
raise InvalidURI("The empty string is not valid username.")
|
||||
|
||||
return unquote_plus(user), unquote_plus(passwd)
|
||||
|
||||
|
||||
def parse_ipv6_literal_host(
|
||||
entity: str, default_port: Optional[int]
|
||||
) -> tuple[str, Optional[Union[str, int]]]:
|
||||
"""Validates an IPv6 literal host:port string.
|
||||
|
||||
Returns a 2-tuple of IPv6 literal followed by port where
|
||||
port is default_port if it wasn't specified in entity.
|
||||
|
||||
:param entity: A string that represents an IPv6 literal enclosed
|
||||
in braces (e.g. '[::1]' or '[::1]:27017').
|
||||
:param default_port: The port number to use when one wasn't
|
||||
specified in entity.
|
||||
"""
|
||||
if entity.find("]") == -1:
|
||||
raise ValueError(
|
||||
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
|
||||
)
|
||||
i = entity.find("]:")
|
||||
if i == -1:
|
||||
return entity[1:-1], default_port
|
||||
return entity[1:i], entity[i + 2 :]
|
||||
|
||||
|
||||
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
|
||||
"""Validates a host string
|
||||
|
||||
Returns a 2-tuple of host followed by port where port is default_port
|
||||
if it wasn't specified in the string.
|
||||
|
||||
:param entity: A host or host:port string where host could be a
|
||||
hostname or IP address.
|
||||
:param default_port: The port number to use when one wasn't
|
||||
specified in entity.
|
||||
"""
|
||||
host = entity
|
||||
port: Optional[Union[str, int]] = default_port
|
||||
if entity[0] == "[":
|
||||
host, port = parse_ipv6_literal_host(entity, default_port)
|
||||
elif entity.endswith(".sock"):
|
||||
return entity, default_port
|
||||
elif entity.find(":") != -1:
|
||||
if entity.count(":") > 1:
|
||||
raise ValueError(
|
||||
"Reserved characters such as ':' must be "
|
||||
"escaped according RFC 2396. An IPv6 "
|
||||
"address literal must be enclosed in '[' "
|
||||
"and ']' according to RFC 2732."
|
||||
)
|
||||
host, port = host.split(":", 1)
|
||||
if isinstance(port, str):
|
||||
if not port.isdigit() or int(port) > 65535 or int(port) <= 0:
|
||||
raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}")
|
||||
port = int(port)
|
||||
|
||||
# Normalize hostname to lowercase, since DNS is case-insensitive:
|
||||
# http://tools.ietf.org/html/rfc4343
|
||||
# This prevents useless rediscovery if "foo.com" is in the seed list but
|
||||
# "FOO.com" is in the hello response.
|
||||
return host.lower(), port
|
||||
|
||||
|
||||
# Options whose values are implicitly determined by tlsInsecure.
|
||||
_IMPLICIT_TLSINSECURE_OPTS = {
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
}
|
||||
|
||||
|
||||
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
|
||||
"""Helper method for split_options which creates the options dict.
|
||||
Also handles the creation of a list for the URI tag_sets/
|
||||
readpreferencetags portion, and the use of a unicode options string.
|
||||
"""
|
||||
options = _CaseInsensitiveDictionary()
|
||||
for uriopt in opts.split(delim):
|
||||
key, value = uriopt.split("=")
|
||||
if key.lower() == "readpreferencetags":
|
||||
options.setdefault(key, []).append(value)
|
||||
else:
|
||||
if key in options:
|
||||
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
|
||||
if key.lower() == "authmechanismproperties":
|
||||
val = value
|
||||
else:
|
||||
val = unquote_plus(value)
|
||||
options[key] = val
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Raise appropriate errors when conflicting TLS options are present in
|
||||
the options dictionary.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
# Implicitly defined options must not be explicitly specified.
|
||||
tlsinsecure = options.get("tlsinsecure")
|
||||
if tlsinsecure is not None:
|
||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||
if opt in options:
|
||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||
raise InvalidURI(
|
||||
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
|
||||
)
|
||||
|
||||
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
|
||||
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
|
||||
if tlsallowinvalidcerts is not None:
|
||||
if "tlsdisableocspendpointcheck" in options:
|
||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||
raise InvalidURI(
|
||||
err_msg
|
||||
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
|
||||
)
|
||||
if tlsallowinvalidcerts is True:
|
||||
options["tlsdisableocspendpointcheck"] = True
|
||||
|
||||
# Handle co-occurence of CRL and OCSP-related options.
|
||||
tlscrlfile = options.get("tlscrlfile")
|
||||
if tlscrlfile is not None:
|
||||
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
|
||||
if options.get(opt) is True:
|
||||
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
|
||||
raise InvalidURI(err_msg % (opt,))
|
||||
|
||||
if "ssl" in options and "tls" in options:
|
||||
|
||||
def truth_value(val: Any) -> Any:
|
||||
if val in ("true", "false"):
|
||||
return val == "true"
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
return val
|
||||
|
||||
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
|
||||
err_msg = "Can not specify conflicting values for URI options %s and %s."
|
||||
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Issue appropriate warnings when deprecated options are present in the
|
||||
options dictionary. Removes deprecated option key, value pairs if the
|
||||
options dictionary is found to also have the renamed option.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
for optname in list(options):
|
||||
if optname in URI_OPTIONS_DEPRECATION_MAP:
|
||||
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
|
||||
if mode == "renamed":
|
||||
newoptname = message
|
||||
if newoptname in options:
|
||||
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
options.pop(optname)
|
||||
continue
|
||||
warn_msg = "Option '%s' is deprecated, use '%s' instead."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), newoptname),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif mode == "removed":
|
||||
warn_msg = "Option '%s' is deprecated. %s."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), message),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Normalizes option names in the options dictionary by converting them to
|
||||
their internally-used names.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
# Expand the tlsInsecure option.
|
||||
tlsinsecure = options.get("tlsinsecure")
|
||||
if tlsinsecure is not None:
|
||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||
# Implicit options are logically the same as tlsInsecure.
|
||||
options[opt] = tlsinsecure
|
||||
|
||||
for optname in list(options):
|
||||
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
|
||||
if intname is not None:
|
||||
options[intname] = options.pop(optname)
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
|
||||
"""Validates and normalizes options passed in a MongoDB URI.
|
||||
|
||||
Returns a new dictionary of validated and normalized options. If warn is
|
||||
False then errors will be thrown for invalid options, otherwise they will
|
||||
be ignored and a warning will be issued.
|
||||
|
||||
:param opts: A dict of MongoDB URI options.
|
||||
:param warn: If ``True`` then warnings will be logged and
|
||||
invalid options will be ignored. Otherwise invalid options will
|
||||
cause errors.
|
||||
"""
|
||||
return get_validated_options(opts, warn)
|
||||
|
||||
|
||||
def split_options(
|
||||
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Takes the options portion of a MongoDB URI, validates each option
|
||||
and returns the options in a dictionary.
|
||||
|
||||
:param opt: A string representing MongoDB URI options.
|
||||
:param validate: If ``True`` (the default), validate and normalize all
|
||||
options.
|
||||
:param warn: If ``False`` (the default), suppress all warnings raised
|
||||
during validation of options.
|
||||
:param normalize: If ``True`` (the default), renames all options to their
|
||||
internally-used names.
|
||||
"""
|
||||
and_idx = opts.find("&")
|
||||
semi_idx = opts.find(";")
|
||||
try:
|
||||
if and_idx >= 0 and semi_idx >= 0:
|
||||
raise InvalidURI("Can not mix '&' and ';' for option separators.")
|
||||
elif and_idx >= 0:
|
||||
options = _parse_options(opts, "&")
|
||||
elif semi_idx >= 0:
|
||||
options = _parse_options(opts, ";")
|
||||
elif opts.find("=") != -1:
|
||||
options = _parse_options(opts, None)
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidURI("MongoDB URI options are key=value pairs.") from None
|
||||
|
||||
options = _handle_security_options(options)
|
||||
|
||||
options = _handle_option_deprecations(options)
|
||||
|
||||
if normalize:
|
||||
options = _normalize_options(options)
|
||||
|
||||
if validate:
|
||||
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
|
||||
if options.get("authsource") == "":
|
||||
raise InvalidURI("the authSource database cannot be an empty string")
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
|
||||
"""Takes a string of the form host1[:port],host2[:port]... and
|
||||
splits it into (host, port) tuples. If [:port] isn't present the
|
||||
default_port is used.
|
||||
|
||||
Returns a set of 2-tuples containing the host name (or IP) followed by
|
||||
port number.
|
||||
|
||||
:param hosts: A string of the form host1[:port],host2[:port],...
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host.
|
||||
"""
|
||||
nodes = []
|
||||
for entity in hosts.split(","):
|
||||
if not entity:
|
||||
raise ConfigurationError("Empty host (or extra comma in host list).")
|
||||
port = default_port
|
||||
# Unix socket entities don't have ports
|
||||
if entity.endswith(".sock"):
|
||||
port = None
|
||||
nodes.append(parse_host(entity, port))
|
||||
return nodes
|
||||
|
||||
|
||||
# Prohibited characters in database name. DB names also can't have ".", but for
|
||||
# backward-compat we allow "db.collection" in URI.
|
||||
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
|
||||
|
||||
_ALLOWED_TXT_OPTS = frozenset(
|
||||
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
|
||||
)
|
||||
|
||||
|
||||
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
||||
# Ensure directConnection was not True if there are multiple seeds.
|
||||
if len(nodes) > 1 and options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
|
||||
|
||||
if options.get("loadbalanced"):
|
||||
if len(nodes) > 1:
|
||||
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
|
||||
if options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
|
||||
if options.get("replicaset"):
|
||||
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
||||
|
||||
|
||||
def parse_uri(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse and validate a MongoDB URI.
|
||||
|
||||
Returns a dict of the form::
|
||||
|
||||
{
|
||||
'nodelist': <list of (host, port) tuples>,
|
||||
'username': <username> or None,
|
||||
'password': <password> or None,
|
||||
'database': <database name> or None,
|
||||
'collection': <collection name> or None,
|
||||
'options': <dict of MongoDB URI options>,
|
||||
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||
}
|
||||
|
||||
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||
to build nodelist and options.
|
||||
|
||||
:param uri: The MongoDB URI to parse.
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host in the URI.
|
||||
:param validate: If ``True`` (the default), validate and
|
||||
normalize all options. Default: ``True``.
|
||||
:param warn: When validating, if ``True`` then will warn
|
||||
the user then ignore any invalid options or values. If ``False``,
|
||||
validation will error when options are unsupported or values are
|
||||
invalid. Default: ``False``.
|
||||
:param normalize: If ``True``, convert names of URI options
|
||||
to their internally-used names. Default: ``True``.
|
||||
:param connect_timeout: The maximum time in milliseconds to
|
||||
wait for a response from the DNS server.
|
||||
:param srv_service_name: A custom SRV service name
|
||||
|
||||
.. versionchanged:: 4.6
|
||||
The delimiting slash (``/``) between hosts and connection options is now optional.
|
||||
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
||||
supported.
|
||||
|
||||
.. versionchanged:: 3.9
|
||||
Added the ``normalize`` parameter.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added support for mongodb+srv:// URIs.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
Return the original value of the ``readPreference`` MongoDB URI option
|
||||
instead of the validated read preference mode.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``warn`` added so invalid options can be ignored.
|
||||
"""
|
||||
if uri.startswith(SCHEME):
|
||||
is_srv = False
|
||||
scheme_free = uri[SCHEME_LEN:]
|
||||
elif uri.startswith(SRV_SCHEME):
|
||||
if not _have_dnspython():
|
||||
python_path = sys.executable or "python"
|
||||
raise ConfigurationError(
|
||||
'The "dnspython" module must be '
|
||||
"installed to use mongodb+srv:// URIs. "
|
||||
"To fix this error install pymongo again:\n "
|
||||
"%s -m pip install pymongo>=4.3" % (python_path)
|
||||
)
|
||||
is_srv = True
|
||||
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||
else:
|
||||
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
|
||||
|
||||
if not scheme_free:
|
||||
raise InvalidURI("Must provide at least one hostname or IP.")
|
||||
|
||||
user = None
|
||||
passwd = None
|
||||
dbase = None
|
||||
collection = None
|
||||
options = _CaseInsensitiveDictionary()
|
||||
|
||||
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||
if "/" in host_plus_db_part:
|
||||
host_part, _, dbase = host_plus_db_part.partition("/")
|
||||
else:
|
||||
host_part = host_plus_db_part
|
||||
|
||||
if dbase:
|
||||
dbase = unquote_plus(dbase)
|
||||
if "." in dbase:
|
||||
dbase, collection = dbase.split(".", 1)
|
||||
if _BAD_DB_CHARS.search(dbase):
|
||||
raise InvalidURI('Bad database name "%s"' % dbase)
|
||||
else:
|
||||
dbase = None
|
||||
|
||||
if opts:
|
||||
options.update(split_options(opts, validate, warn, normalize))
|
||||
if srv_service_name is None:
|
||||
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
||||
if "@" in host_part:
|
||||
userinfo, _, hosts = host_part.rpartition("@")
|
||||
user, passwd = parse_userinfo(userinfo)
|
||||
else:
|
||||
hosts = host_part
|
||||
|
||||
if "/" in hosts:
|
||||
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
|
||||
|
||||
hosts = unquote_plus(hosts)
|
||||
fqdn = None
|
||||
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||
if is_srv:
|
||||
if options.get("directConnection"):
|
||||
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
|
||||
nodes = split_hosts(hosts, default_port=None)
|
||||
if len(nodes) != 1:
|
||||
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
|
||||
fqdn, port = nodes[0]
|
||||
if port is not None:
|
||||
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
|
||||
|
||||
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
||||
# argument overrides the same option passed in the connection string.
|
||||
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
||||
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
||||
nodes = dns_resolver.get_hosts()
|
||||
dns_options = dns_resolver.get_options()
|
||||
if dns_options:
|
||||
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
||||
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
||||
raise ConfigurationError(
|
||||
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
||||
)
|
||||
for opt, val in parsed_dns_options.items():
|
||||
if opt not in options:
|
||||
options[opt] = val
|
||||
if options.get("loadBalanced") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
||||
if options.get("replicaSet") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
||||
if "tls" not in options and "ssl" not in options:
|
||||
options["tls"] = True if validate else "true"
|
||||
elif not is_srv and options.get("srvServiceName") is not None:
|
||||
raise ConfigurationError(
|
||||
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
|
||||
)
|
||||
elif not is_srv and srv_max_hosts:
|
||||
raise ConfigurationError(
|
||||
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
|
||||
)
|
||||
else:
|
||||
nodes = split_hosts(hosts, default_port=default_port)
|
||||
|
||||
_check_options(nodes, options)
|
||||
|
||||
return {
|
||||
"nodelist": nodes,
|
||||
"username": user,
|
||||
"password": passwd,
|
||||
"database": dbase,
|
||||
"collection": collection,
|
||||
"options": options,
|
||||
"fqdn": fqdn,
|
||||
}
|
||||
|
||||
|
||||
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
||||
"""Parse KMS TLS connection options."""
|
||||
if not kms_tls_options:
|
||||
return {}
|
||||
if not isinstance(kms_tls_options, dict):
|
||||
raise TypeError("kms_tls_options must be a dict")
|
||||
contexts = {}
|
||||
for provider, options in kms_tls_options.items():
|
||||
if not isinstance(options, dict):
|
||||
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
|
||||
options.setdefault("tls", True)
|
||||
opts = _CaseInsensitiveDictionary(options)
|
||||
opts = _handle_security_options(opts)
|
||||
opts = _normalize_options(opts)
|
||||
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
||||
if ssl_context is None:
|
||||
raise ConfigurationError("TLS is required for KMS providers")
|
||||
if allow_invalid_hostnames:
|
||||
raise ConfigurationError("Insecure TLS options prohibited")
|
||||
|
||||
for n in [
|
||||
"tlsInsecure",
|
||||
"tlsAllowInvalidCertificates",
|
||||
"tlsAllowInvalidHostnames",
|
||||
"tlsDisableCertificateRevocationCheck",
|
||||
]:
|
||||
if n in opts:
|
||||
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
|
||||
contexts[provider] = ssl_context
|
||||
return contexts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pprint
|
||||
|
||||
try:
|
||||
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
|
||||
except InvalidURI as exc:
|
||||
print(exc) # noqa: T201
|
||||
sys.exit(0)
|
||||
@ -15,6 +15,7 @@
|
||||
"""Re-import of synchronous Auth API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.auth_shared import * # noqa: F403
|
||||
from pymongo.synchronous.auth import * # noqa: F403
|
||||
from pymongo.synchronous.auth import __doc__ as original_doc
|
||||
|
||||
|
||||
@ -15,7 +15,9 @@
|
||||
"""Re-import of synchronous AuthOIDC API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.auth_oidc_shared import * # noqa: F403
|
||||
from pymongo.synchronous.auth_oidc import * # noqa: F403
|
||||
from pymongo.synchronous.auth_oidc import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["OIDCCallback", "OIDCCallbackContext", "OIDCCallbackResult", "OIDCIdPInfo"] # noqa: F405
|
||||
|
||||
118
pymongo/auth_oidc_shared.py
Normal file
118
pymongo/auth_oidc_shared.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Constants, types, and classes shared across OIDC auth implementations."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from pymongo._azure_helpers import _get_azure_response
|
||||
from pymongo._gcp_helpers import _get_gcp_response
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCIdPInfo:
|
||||
issuer: str
|
||||
clientId: Optional[str] = field(default=None)
|
||||
requestScopes: Optional[list[str]] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCCallbackContext:
|
||||
timeout_seconds: float
|
||||
username: str
|
||||
version: int
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCCallbackResult:
|
||||
access_token: str
|
||||
expires_in_seconds: Optional[float] = field(default=None)
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
class OIDCCallback(abc.ABC):
|
||||
"""A base class for defining OIDC callbacks."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
"""Convert the given BSON value into our own type."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCProperties:
|
||||
callback: Optional[OIDCCallback] = field(default=None)
|
||||
human_callback: Optional[OIDCCallback] = field(default=None)
|
||||
environment: Optional[str] = field(default=None)
|
||||
allowed_hosts: list[str] = field(default_factory=list)
|
||||
token_resource: Optional[str] = field(default=None)
|
||||
username: str = ""
|
||||
|
||||
|
||||
"""Mechanism properties for MONGODB-OIDC authentication."""
|
||||
|
||||
TOKEN_BUFFER_MINUTES = 5
|
||||
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
|
||||
CALLBACK_VERSION = 1
|
||||
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
|
||||
TIME_BETWEEN_CALLS_SECONDS = 0.1
|
||||
|
||||
|
||||
class _OIDCTestCallback(OIDCCallback):
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
token_file = os.environ.get("OIDC_TOKEN_FILE")
|
||||
if not token_file:
|
||||
raise RuntimeError(
|
||||
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
|
||||
)
|
||||
with open(token_file) as fid:
|
||||
return OIDCCallbackResult(access_token=fid.read().strip())
|
||||
|
||||
|
||||
class _OIDCAWSCallback(OIDCCallback):
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
|
||||
if not token_file:
|
||||
raise RuntimeError(
|
||||
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
|
||||
)
|
||||
with open(token_file) as fid:
|
||||
return OIDCCallbackResult(access_token=fid.read().strip())
|
||||
|
||||
|
||||
class _OIDCAzureCallback(OIDCCallback):
|
||||
def __init__(self, token_resource: str) -> None:
|
||||
self.token_resource = quote(token_resource)
|
||||
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
|
||||
return OIDCCallbackResult(
|
||||
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
|
||||
)
|
||||
|
||||
|
||||
class _OIDCGCPCallback(OIDCCallback):
|
||||
def __init__(self, token_resource: str) -> None:
|
||||
self.token_resource = quote(token_resource)
|
||||
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
|
||||
return OIDCCallbackResult(access_token=resp["access_token"])
|
||||
236
pymongo/auth_shared.py
Normal file
236
pymongo/auth_shared.py
Normal file
@ -0,0 +1,236 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Constants and types shared across multiple auth types."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing
|
||||
from base64 import standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
from bson import Binary
|
||||
from pymongo.auth_oidc_shared import (
|
||||
_OIDCAzureCallback,
|
||||
_OIDCGCPCallback,
|
||||
_OIDCProperties,
|
||||
_OIDCTestCallback,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
MECHANISMS = frozenset(
|
||||
[
|
||||
"GSSAPI",
|
||||
"MONGODB-CR",
|
||||
"MONGODB-OIDC",
|
||||
"MONGODB-X509",
|
||||
"MONGODB-AWS",
|
||||
"PLAIN",
|
||||
"SCRAM-SHA-1",
|
||||
"SCRAM-SHA-256",
|
||||
"DEFAULT",
|
||||
]
|
||||
)
|
||||
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
class _Cache:
|
||||
__slots__ = ("data",)
|
||||
|
||||
_hash_val = hash("_Cache")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = None
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Two instances must always compare equal.
|
||||
if isinstance(other, _Cache):
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, _Cache):
|
||||
return False
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self._hash_val
|
||||
|
||||
|
||||
MongoCredential = namedtuple(
|
||||
"MongoCredential",
|
||||
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
|
||||
)
|
||||
"""A hashable namedtuple of values used for authentication."""
|
||||
|
||||
|
||||
GSSAPIProperties = namedtuple(
|
||||
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
|
||||
)
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(
|
||||
mech: str,
|
||||
source: Optional[str],
|
||||
user: str,
|
||||
passwd: str,
|
||||
extra: Mapping[str, Any],
|
||||
database: Optional[str],
|
||||
) -> MongoCredential:
|
||||
"""Build and return a mechanism specific credentials tuple."""
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
|
||||
raise ConfigurationError(f"{mech} requires a username.")
|
||||
if mech == "GSSAPI":
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for GSSAPI")
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
service_name = properties.get("SERVICE_NAME", "mongodb")
|
||||
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
|
||||
service_realm = properties.get("SERVICE_REALM")
|
||||
props = GSSAPIProperties(
|
||||
service_name=service_name,
|
||||
canonicalize_host_name=canonicalize,
|
||||
service_realm=service_realm,
|
||||
)
|
||||
# Source is always $external.
|
||||
return MongoCredential(mech, "$external", user, passwd, props, None)
|
||||
elif mech == "MONGODB-X509":
|
||||
if passwd is not None:
|
||||
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for MONGODB-X509")
|
||||
# Source is always $external, user can be None.
|
||||
return MongoCredential(mech, "$external", user, None, None, None)
|
||||
elif mech == "MONGODB-AWS":
|
||||
if user is not None and passwd is None:
|
||||
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
|
||||
if source is not None and source != "$external":
|
||||
raise ConfigurationError(
|
||||
"authentication source must be $external or None for MONGODB-AWS"
|
||||
)
|
||||
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
aws_session_token = properties.get("AWS_SESSION_TOKEN")
|
||||
aws_props = _AWSProperties(aws_session_token=aws_session_token)
|
||||
# user can be None for temporary link-local EC2 credentials.
|
||||
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
|
||||
elif mech == "MONGODB-OIDC":
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
callback = properties.get("OIDC_CALLBACK")
|
||||
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
|
||||
environ = properties.get("ENVIRONMENT")
|
||||
token_resource = properties.get("TOKEN_RESOURCE", "")
|
||||
default_allowed = [
|
||||
"*.mongodb.net",
|
||||
"*.mongodb-dev.net",
|
||||
"*.mongodb-qa.net",
|
||||
"*.mongodbgov.net",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
]
|
||||
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
|
||||
msg = (
|
||||
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
|
||||
)
|
||||
if passwd is not None:
|
||||
msg = "password is not supported by MONGODB-OIDC"
|
||||
raise ConfigurationError(msg)
|
||||
if callback or human_callback:
|
||||
if environ is not None:
|
||||
raise ConfigurationError(msg)
|
||||
if callback and human_callback:
|
||||
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
|
||||
raise ConfigurationError(msg)
|
||||
elif environ is not None:
|
||||
if environ == "test":
|
||||
if user is not None:
|
||||
msg = "test environment for MONGODB-OIDC does not support username"
|
||||
raise ConfigurationError(msg)
|
||||
callback = _OIDCTestCallback()
|
||||
elif environ == "azure":
|
||||
passwd = None
|
||||
if not token_resource:
|
||||
raise ConfigurationError(
|
||||
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
|
||||
)
|
||||
callback = _OIDCAzureCallback(token_resource)
|
||||
elif environ == "gcp":
|
||||
passwd = None
|
||||
if not token_resource:
|
||||
raise ConfigurationError(
|
||||
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
|
||||
)
|
||||
callback = _OIDCGCPCallback(token_resource)
|
||||
else:
|
||||
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
|
||||
else:
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
oidc_props = _OIDCProperties(
|
||||
callback=callback,
|
||||
human_callback=human_callback,
|
||||
environment=environ,
|
||||
allowed_hosts=allowed_hosts,
|
||||
token_resource=token_resource,
|
||||
username=user,
|
||||
)
|
||||
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
|
||||
|
||||
elif mech == "PLAIN":
|
||||
source_database = source or database or "$external"
|
||||
return MongoCredential(mech, source_database, user, passwd, None, None)
|
||||
else:
|
||||
source_database = source or database or "admin"
|
||||
if passwd is None:
|
||||
raise ConfigurationError("A password is required.")
|
||||
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
"""XOR two byte strings together."""
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
|
||||
for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_scram_start(
|
||||
credentials: MongoCredential, mechanism: str
|
||||
) -> tuple[bytes, bytes, typing.MutableMapping[str, Any]]:
|
||||
username = credentials.username
|
||||
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
|
||||
nonce = standard_b64encode(os.urandom(32))
|
||||
first_bare = b"n=" + user + b",r=" + nonce
|
||||
|
||||
cmd = {
|
||||
"saslStart": 1,
|
||||
"mechanism": mechanism,
|
||||
"payload": Binary(b"n,," + first_bare),
|
||||
"autoAuthorize": 1,
|
||||
"options": {"skipEmptyExchange": True},
|
||||
}
|
||||
return nonce, first_bare, cmd
|
||||
131
pymongo/bulk_shared.py
Normal file
131
pymongo/bulk_shared.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Constants, types, and classes shared across Bulk Write API implementations."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, NoReturn
|
||||
|
||||
from pymongo.errors import BulkWriteError, OperationFailure
|
||||
from pymongo.helpers_shared import _get_wce_doc
|
||||
from pymongo.message import (
|
||||
_DELETE,
|
||||
_INSERT,
|
||||
_UPDATE,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
_DELETE_ONE: int = 1
|
||||
|
||||
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
|
||||
_BAD_VALUE: int = 2
|
||||
_UNKNOWN_ERROR: int = 8
|
||||
_WRITE_CONCERN_ERROR: int = 64
|
||||
|
||||
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
|
||||
|
||||
|
||||
class _Run:
|
||||
"""Represents a batch of write operations."""
|
||||
|
||||
def __init__(self, op_type: int) -> None:
|
||||
"""Initialize a new Run object."""
|
||||
self.op_type: int = op_type
|
||||
self.index_map: list[int] = []
|
||||
self.ops: list[Any] = []
|
||||
self.idx_offset: int = 0
|
||||
|
||||
def index(self, idx: int) -> int:
|
||||
"""Get the original index of an operation in this run.
|
||||
|
||||
:param idx: The Run index that maps to the original index.
|
||||
"""
|
||||
return self.index_map[idx]
|
||||
|
||||
def add(self, original_index: int, operation: Any) -> None:
|
||||
"""Add an operation to this Run instance.
|
||||
|
||||
:param original_index: The original index of this operation
|
||||
within a larger bulk operation.
|
||||
:param operation: The operation document.
|
||||
"""
|
||||
self.index_map.append(original_index)
|
||||
self.ops.append(operation)
|
||||
|
||||
|
||||
def _merge_command(
|
||||
run: _Run,
|
||||
full_result: MutableMapping[str, Any],
|
||||
offset: int,
|
||||
result: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""Merge a write command result into the full bulk result."""
|
||||
affected = result.get("n", 0)
|
||||
|
||||
if run.op_type == _INSERT:
|
||||
full_result["nInserted"] += affected
|
||||
|
||||
elif run.op_type == _DELETE:
|
||||
full_result["nRemoved"] += affected
|
||||
|
||||
elif run.op_type == _UPDATE:
|
||||
upserted = result.get("upserted")
|
||||
if upserted:
|
||||
n_upserted = len(upserted)
|
||||
for doc in upserted:
|
||||
doc["index"] = run.index(doc["index"] + offset)
|
||||
full_result["upserted"].extend(upserted)
|
||||
full_result["nUpserted"] += n_upserted
|
||||
full_result["nMatched"] += affected - n_upserted
|
||||
else:
|
||||
full_result["nMatched"] += affected
|
||||
full_result["nModified"] += result["nModified"]
|
||||
|
||||
write_errors = result.get("writeErrors")
|
||||
if write_errors:
|
||||
for doc in write_errors:
|
||||
# Leave the server response intact for APM.
|
||||
replacement = doc.copy()
|
||||
idx = doc["index"] + offset
|
||||
replacement["index"] = run.index(idx)
|
||||
# Add the failed operation to the error document.
|
||||
replacement["op"] = run.ops[idx]
|
||||
full_result["writeErrors"].append(replacement)
|
||||
|
||||
wce = _get_wce_doc(result)
|
||||
if wce:
|
||||
full_result["writeConcernErrors"].append(wce)
|
||||
|
||||
|
||||
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
|
||||
"""Raise a BulkWriteError from the full bulk api result."""
|
||||
# retryWrites on MMAPv1 should raise an actionable error.
|
||||
if full_result["writeErrors"]:
|
||||
full_result["writeErrors"].sort(key=lambda error: error["index"])
|
||||
err = full_result["writeErrors"][0]
|
||||
code = err["code"]
|
||||
msg = err["errmsg"]
|
||||
if code == 20 and msg.startswith("Transaction numbers"):
|
||||
errmsg = (
|
||||
"This MongoDB deployment does not support "
|
||||
"retryable writes. Please add retryWrites=false "
|
||||
"to your connection string."
|
||||
)
|
||||
raise OperationFailure(errmsg, code, full_result)
|
||||
raise BulkWriteError(full_result)
|
||||
@ -19,3 +19,4 @@ from pymongo.synchronous.change_stream import * # noqa: F403
|
||||
from pymongo.synchronous.change_stream import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["ChangeStream", "ClusterChangeStream", "CollectionChangeStream", "DatabaseChangeStream"] # noqa: F405
|
||||
|
||||
@ -1,21 +1,332 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Re-import of synchronous ClientOptions API for compatibility."""
|
||||
"""Tools to parse mongo client options."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.client_options import * # noqa: F403
|
||||
from pymongo.synchronous.client_options import __doc__ as original_doc
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
__doc__ = original_doc
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo import common
|
||||
from pymongo.compression_support import CompressionSettings
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _EventListener, _EventListeners
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import (
|
||||
_ServerMode,
|
||||
make_read_preference,
|
||||
read_pref_mode_from_name,
|
||||
)
|
||||
from pymongo.server_selectors import any_server_selector
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
from pymongo.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.topology_description import _ServerSelector
|
||||
|
||||
|
||||
def _parse_credentials(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> Optional[MongoCredential]:
|
||||
"""Parse authentication credentials."""
|
||||
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
|
||||
source = options.get("authsource")
|
||||
if username or mechanism:
|
||||
from pymongo.auth_shared import _build_credentials_tuple
|
||||
|
||||
return _build_credentials_tuple(mechanism, source, username, password, options, database)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
|
||||
"""Parse read preference options."""
|
||||
if "read_preference" in options:
|
||||
return options["read_preference"]
|
||||
|
||||
name = options.get("readpreference", "primary")
|
||||
mode = read_pref_mode_from_name(name)
|
||||
tags = options.get("readpreferencetags")
|
||||
max_staleness = options.get("maxstalenessseconds", -1)
|
||||
return make_read_preference(mode, tags, max_staleness)
|
||||
|
||||
|
||||
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
|
||||
"""Parse write concern options."""
|
||||
concern = options.get("w")
|
||||
wtimeout = options.get("wtimeoutms")
|
||||
j = options.get("journal")
|
||||
fsync = options.get("fsync")
|
||||
return WriteConcern(concern, wtimeout, j, fsync)
|
||||
|
||||
|
||||
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
|
||||
"""Parse read concern options."""
|
||||
concern = options.get("readconcernlevel")
|
||||
return ReadConcern(concern)
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
|
||||
"""Parse ssl options."""
|
||||
use_tls = options.get("tls")
|
||||
if use_tls is not None:
|
||||
validate_boolean("tls", use_tls)
|
||||
|
||||
certfile = options.get("tlscertificatekeyfile")
|
||||
passphrase = options.get("tlscertificatekeyfilepassword")
|
||||
ca_certs = options.get("tlscafile")
|
||||
crlfile = options.get("tlscrlfile")
|
||||
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
|
||||
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
|
||||
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
|
||||
|
||||
enabled_tls_opts = []
|
||||
for opt in (
|
||||
"tlscertificatekeyfile",
|
||||
"tlscertificatekeyfilepassword",
|
||||
"tlscafile",
|
||||
"tlscrlfile",
|
||||
):
|
||||
# Any non-null value of these options implies tls=True.
|
||||
if opt in options and options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
for opt in (
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
):
|
||||
# A value of False for these options implies tls=True.
|
||||
if opt in options and not options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
|
||||
if enabled_tls_opts:
|
||||
if use_tls is None:
|
||||
# Implicitly enable TLS when one of the tls* options is set.
|
||||
use_tls = True
|
||||
elif not use_tls:
|
||||
# Error since tls is explicitly disabled but a tls option is set.
|
||||
raise ConfigurationError(
|
||||
"TLS has not been enabled but the "
|
||||
"following tls parameters have been set: "
|
||||
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
|
||||
)
|
||||
|
||||
if use_tls:
|
||||
ctx = get_ssl_context(
|
||||
certfile,
|
||||
passphrase,
|
||||
ca_certs,
|
||||
crlfile,
|
||||
allow_invalid_certificates,
|
||||
allow_invalid_hostnames,
|
||||
disable_ocsp_endpoint_check,
|
||||
)
|
||||
return ctx, allow_invalid_hostnames
|
||||
return None, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> PoolOptions:
|
||||
"""Parse connection pool options."""
|
||||
credentials = _parse_credentials(username, password, database, options)
|
||||
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
|
||||
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
|
||||
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
|
||||
if max_pool_size is not None and min_pool_size > max_pool_size:
|
||||
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
|
||||
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
|
||||
socket_timeout = options.get("sockettimeoutms")
|
||||
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
|
||||
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
|
||||
appname = options.get("appname")
|
||||
driver = options.get("driver")
|
||||
server_api = options.get("server_api")
|
||||
compression_settings = CompressionSettings(
|
||||
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
|
||||
)
|
||||
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
|
||||
load_balanced = options.get("loadbalanced")
|
||||
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
|
||||
return PoolOptions(
|
||||
max_pool_size,
|
||||
min_pool_size,
|
||||
max_idle_time_seconds,
|
||||
connect_timeout,
|
||||
socket_timeout,
|
||||
wait_queue_timeout,
|
||||
ssl_context,
|
||||
tls_allow_invalid_hostnames,
|
||||
_EventListeners(event_listeners),
|
||||
appname,
|
||||
driver,
|
||||
compression_settings,
|
||||
max_connecting=max_connecting,
|
||||
server_api=server_api,
|
||||
load_balanced=load_balanced,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
|
||||
class ClientOptions:
|
||||
"""Read only configuration options for an AsyncMongoClient/MongoClient.
|
||||
|
||||
Should not be instantiated directly by application developers. Access
|
||||
a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options` or :attr:`pymongo.mongo_client.MongoClient.options`
|
||||
instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
):
|
||||
self.__options = options
|
||||
self.__codec_options = _parse_codec_options(options)
|
||||
self.__direct_connection = options.get("directconnection")
|
||||
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
|
||||
# self.__server_selection_timeout is in seconds. Must use full name for
|
||||
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
|
||||
self.__server_selection_timeout = options.get(
|
||||
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
|
||||
)
|
||||
self.__pool_options = _parse_pool_options(username, password, database, options)
|
||||
self.__read_preference = _parse_read_preference(options)
|
||||
self.__replica_set_name = options.get("replicaset")
|
||||
self.__write_concern = _parse_write_concern(options)
|
||||
self.__read_concern = _parse_read_concern(options)
|
||||
self.__connect = options.get("connect")
|
||||
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
|
||||
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
|
||||
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
|
||||
self.__server_selector = options.get("server_selector", any_server_selector)
|
||||
self.__auto_encryption_opts = options.get("auto_encryption_opts")
|
||||
self.__load_balanced = options.get("loadbalanced")
|
||||
self.__timeout = options.get("timeoutms")
|
||||
self.__server_monitoring_mode = options.get(
|
||||
"servermonitoringmode", common.SERVER_MONITORING_MODE
|
||||
)
|
||||
|
||||
@property
|
||||
def _options(self) -> Mapping[str, Any]:
|
||||
"""The original options used to create this ClientOptions."""
|
||||
return self.__options
|
||||
|
||||
@property
|
||||
def connect(self) -> Optional[bool]:
|
||||
"""Whether to begin discovering a MongoDB topology automatically."""
|
||||
return self.__connect
|
||||
|
||||
@property
|
||||
def codec_options(self) -> CodecOptions:
|
||||
"""A :class:`~bson.codec_options.CodecOptions` instance."""
|
||||
return self.__codec_options
|
||||
|
||||
@property
|
||||
def direct_connection(self) -> Optional[bool]:
|
||||
"""Whether to connect to the deployment in 'Single' topology."""
|
||||
return self.__direct_connection
|
||||
|
||||
@property
|
||||
def local_threshold_ms(self) -> int:
|
||||
"""The local threshold for this instance."""
|
||||
return self.__local_threshold_ms
|
||||
|
||||
@property
|
||||
def server_selection_timeout(self) -> int:
|
||||
"""The server selection timeout for this instance in seconds."""
|
||||
return self.__server_selection_timeout
|
||||
|
||||
@property
|
||||
def server_selector(self) -> _ServerSelector:
|
||||
return self.__server_selector
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
"""The monitoring frequency in seconds."""
|
||||
return self.__heartbeat_frequency
|
||||
|
||||
@property
|
||||
def pool_options(self) -> PoolOptions:
|
||||
"""A :class:`~pymongo.pool.PoolOptions` instance."""
|
||||
return self.__pool_options
|
||||
|
||||
@property
|
||||
def read_preference(self) -> _ServerMode:
|
||||
"""A read preference instance."""
|
||||
return self.__read_preference
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self.__replica_set_name
|
||||
|
||||
@property
|
||||
def write_concern(self) -> WriteConcern:
|
||||
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
|
||||
return self.__write_concern
|
||||
|
||||
@property
|
||||
def read_concern(self) -> ReadConcern:
|
||||
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
|
||||
return self.__read_concern
|
||||
|
||||
@property
|
||||
def timeout(self) -> Optional[float]:
|
||||
"""The configured timeoutMS converted to seconds, or None.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
return self.__timeout
|
||||
|
||||
@property
|
||||
def retry_writes(self) -> bool:
|
||||
"""If this instance should retry supported write operations."""
|
||||
return self.__retry_writes
|
||||
|
||||
@property
|
||||
def retry_reads(self) -> bool:
|
||||
"""If this instance should retry supported read operations."""
|
||||
return self.__retry_reads
|
||||
|
||||
@property
|
||||
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
|
||||
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
|
||||
return self.__auto_encryption_opts
|
||||
|
||||
@property
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if the client was configured to connect to a load balancer."""
|
||||
return self.__load_balanced
|
||||
|
||||
@property
|
||||
def event_listeners(self) -> list[_EventListeners]:
|
||||
"""The event listeners registered for this client.
|
||||
|
||||
See :mod:`~pymongo.monitoring` for details.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
assert self.__pool_options._event_listeners is not None
|
||||
return self.__pool_options._event_listeners.event_listeners()
|
||||
|
||||
@property
|
||||
def server_monitoring_mode(self) -> str:
|
||||
"""The configured serverMonitoringMode option.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
"""
|
||||
return self.__server_monitoring_mode
|
||||
|
||||
@ -19,3 +19,4 @@ from pymongo.synchronous.client_session import * # noqa: F403
|
||||
from pymongo.synchronous.client_session import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["ClientSession", "SessionOptions", "TransactionOptions"] # noqa: F405
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2016-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,10 +12,213 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Re-import of synchronous Collation API for compatibility."""
|
||||
"""Tools for working with `collations`_.
|
||||
|
||||
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.collation import * # noqa: F403
|
||||
from pymongo.synchronous.collation import __doc__ as original_doc
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
__doc__ = original_doc
|
||||
from pymongo import common
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
|
||||
class CollationStrength:
|
||||
"""
|
||||
An enum that defines values for `strength` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
PRIMARY = 1
|
||||
"""Differentiate base (unadorned) characters."""
|
||||
|
||||
SECONDARY = 2
|
||||
"""Differentiate character accents."""
|
||||
|
||||
TERTIARY = 3
|
||||
"""Differentiate character case."""
|
||||
|
||||
QUATERNARY = 4
|
||||
"""Differentiate words with and without punctuation."""
|
||||
|
||||
IDENTICAL = 5
|
||||
"""Differentiate unicode code point (characters are exactly identical)."""
|
||||
|
||||
|
||||
class CollationAlternate:
|
||||
"""
|
||||
An enum that defines values for `alternate` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
NON_IGNORABLE = "non-ignorable"
|
||||
"""Spaces and punctuation are treated as base characters."""
|
||||
|
||||
SHIFTED = "shifted"
|
||||
"""Spaces and punctuation are *not* considered base characters.
|
||||
|
||||
Spaces and punctuation are distinguished regardless when the
|
||||
:class:`~pymongo.collation.Collation` strength is at least
|
||||
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class CollationMaxVariable:
|
||||
"""
|
||||
An enum that defines values for `max_variable` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
PUNCT = "punct"
|
||||
"""Both punctuation and spaces are ignored."""
|
||||
|
||||
SPACE = "space"
|
||||
"""Spaces alone are ignored."""
|
||||
|
||||
|
||||
class CollationCaseFirst:
|
||||
"""
|
||||
An enum that defines values for `case_first` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
UPPER = "upper"
|
||||
"""Sort uppercase characters first."""
|
||||
|
||||
LOWER = "lower"
|
||||
"""Sort lowercase characters first."""
|
||||
|
||||
OFF = "off"
|
||||
"""Default for locale or collation strength."""
|
||||
|
||||
|
||||
class Collation:
|
||||
"""Collation
|
||||
|
||||
:param locale: (string) The locale of the collation. This should be a string
|
||||
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
|
||||
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
|
||||
documentation for a list of supported locales.
|
||||
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
|
||||
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
|
||||
greater than 2). Defaults to ``False``.
|
||||
:param caseFirst: (optional) Specify that either uppercase or lowercase
|
||||
characters take precedence. Must be one of the following values:
|
||||
|
||||
* :data:`~CollationCaseFirst.UPPER`
|
||||
* :data:`~CollationCaseFirst.LOWER`
|
||||
* :data:`~CollationCaseFirst.OFF` (the default)
|
||||
|
||||
:param strength: Specify the comparison strength. This is also
|
||||
known as the ICU comparison level. This must be one of the following
|
||||
values:
|
||||
|
||||
* :data:`~CollationStrength.PRIMARY`
|
||||
* :data:`~CollationStrength.SECONDARY`
|
||||
* :data:`~CollationStrength.TERTIARY` (the default)
|
||||
* :data:`~CollationStrength.QUATERNARY`
|
||||
* :data:`~CollationStrength.IDENTICAL`
|
||||
|
||||
Each successive level builds upon the previous. For example, a
|
||||
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
|
||||
characters based both on the unadorned base character and its accents.
|
||||
|
||||
:param numericOrdering: If ``True``, order numbers numerically
|
||||
instead of in collation order (defaults to ``False``).
|
||||
:param alternate: Specify whether spaces and punctuation are
|
||||
considered base characters. This must be one of the following values:
|
||||
|
||||
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
|
||||
* :data:`~CollationAlternate.SHIFTED`
|
||||
|
||||
:param maxVariable: When `alternate` is
|
||||
:data:`~CollationAlternate.SHIFTED`, this option specifies what
|
||||
characters may be ignored. This must be one of the following values:
|
||||
|
||||
* :data:`~CollationMaxVariable.PUNCT` (the default)
|
||||
* :data:`~CollationMaxVariable.SPACE`
|
||||
|
||||
:param normalization: If ``True``, normalizes text into Unicode
|
||||
NFD. Defaults to ``False``.
|
||||
:param backwards: If ``True``, accents on characters are
|
||||
considered from the back of the word to the front, as it is done in some
|
||||
French dictionary ordering traditions. Defaults to ``False``.
|
||||
:param kwargs: Keyword arguments supplying any additional options
|
||||
to be sent with this Collation object.
|
||||
|
||||
.. versionadded: 3.4
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
locale: str,
|
||||
caseLevel: Optional[bool] = None,
|
||||
caseFirst: Optional[str] = None,
|
||||
strength: Optional[int] = None,
|
||||
numericOrdering: Optional[bool] = None,
|
||||
alternate: Optional[str] = None,
|
||||
maxVariable: Optional[str] = None,
|
||||
normalization: Optional[bool] = None,
|
||||
backwards: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
locale = common.validate_string("locale", locale)
|
||||
self.__document: dict[str, Any] = {"locale": locale}
|
||||
if caseLevel is not None:
|
||||
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
|
||||
if caseFirst is not None:
|
||||
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
|
||||
if strength is not None:
|
||||
self.__document["strength"] = common.validate_integer("strength", strength)
|
||||
if numericOrdering is not None:
|
||||
self.__document["numericOrdering"] = validate_boolean(
|
||||
"numericOrdering", numericOrdering
|
||||
)
|
||||
if alternate is not None:
|
||||
self.__document["alternate"] = common.validate_string("alternate", alternate)
|
||||
if maxVariable is not None:
|
||||
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
|
||||
if normalization is not None:
|
||||
self.__document["normalization"] = validate_boolean("normalization", normalization)
|
||||
if backwards is not None:
|
||||
self.__document["backwards"] = validate_boolean("backwards", backwards)
|
||||
self.__document.update(kwargs)
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""The document representation of this collation.
|
||||
|
||||
.. note::
|
||||
:class:`Collation` is immutable. Mutating the value of
|
||||
:attr:`document` does not mutate this :class:`Collation`.
|
||||
"""
|
||||
return self.__document.copy()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
document = self.document
|
||||
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Collation):
|
||||
return self.document == other.document
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(
|
||||
value: Optional[Union[Mapping[str, Any], Collation]]
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, Collation):
|
||||
return value.document
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
|
||||
|
||||
@ -19,3 +19,7 @@ from pymongo.synchronous.collection import * # noqa: F403
|
||||
from pymongo.synchronous.collection import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = [ # noqa: F405
|
||||
"Collection",
|
||||
"ReturnDocument",
|
||||
]
|
||||
|
||||
@ -19,3 +19,4 @@ from pymongo.synchronous.command_cursor import * # noqa: F403
|
||||
from pymongo.synchronous.command_cursor import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["CommandCursor", "RawBatchCommandCursor"] # noqa: F405
|
||||
|
||||
@ -40,22 +40,21 @@ from bson import SON
|
||||
from bson.binary import UuidRepresentation
|
||||
from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.synchronous.compression_support import (
|
||||
from pymongo.compression_support import (
|
||||
validate_compressors,
|
||||
validate_zlib_compression_level,
|
||||
)
|
||||
from pymongo.synchronous.monitoring import _validate_event_listeners
|
||||
from pymongo.synchronous.read_preferences import _MONGOS_MODES, _ServerMode
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _validate_event_listeners
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.typings import _AgnosticClientSession
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
|
||||
|
||||
@ -69,7 +68,8 @@ MAX_WRITE_BATCH_SIZE = 1000
|
||||
# What this version of PyMongo supports.
|
||||
MIN_SUPPORTED_SERVER_VERSION = "3.6"
|
||||
MIN_SUPPORTED_WIRE_VERSION = 6
|
||||
MAX_SUPPORTED_WIRE_VERSION = 21
|
||||
# MongoDB 8.0
|
||||
MAX_SUPPORTED_WIRE_VERSION = 25
|
||||
|
||||
# Frequency to call hello on servers, in seconds.
|
||||
HEARTBEAT_FREQUENCY = 10
|
||||
@ -380,7 +380,7 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
|
||||
|
||||
def validate_auth_mechanism(option: str, value: Any) -> str:
|
||||
"""Validate the authMechanism URI option."""
|
||||
from pymongo.synchronous.auth import MECHANISMS
|
||||
from pymongo.auth_shared import MECHANISMS
|
||||
|
||||
if value not in MECHANISMS:
|
||||
raise ValueError(f"{option} must be in {tuple(MECHANISMS)}")
|
||||
@ -446,7 +446,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
elif key in ["ALLOWED_HOSTS"] and isinstance(value, list):
|
||||
props[key] = value
|
||||
elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]:
|
||||
from pymongo.synchronous.auth_oidc import OIDCCallback
|
||||
from pymongo.auth_oidc_shared import OIDCCallback
|
||||
|
||||
if not isinstance(value, OIDCCallback):
|
||||
raise ValueError("callback must be an OIDCCallback object")
|
||||
@ -642,7 +642,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
|
||||
"""Validate the driver keyword arg."""
|
||||
if value is None:
|
||||
return value
|
||||
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.encryption_options import AutoEncryptionOpts
|
||||
|
||||
if not isinstance(value, AutoEncryptionOpts):
|
||||
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
|
||||
@ -941,7 +941,7 @@ class BaseObject:
|
||||
"""
|
||||
return self._write_concern
|
||||
|
||||
def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern:
|
||||
def _write_concern_for(self, session: Optional[_AgnosticClientSession]) -> WriteConcern:
|
||||
"""Read only access to the write concern of this instance or session."""
|
||||
# Override this operation's write concern with the transaction's.
|
||||
if session and session.in_transaction:
|
||||
@ -957,7 +957,7 @@ class BaseObject:
|
||||
"""
|
||||
return self._read_preference
|
||||
|
||||
def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode:
|
||||
def _read_preference_for(self, session: Optional[_AgnosticClientSession]) -> _ServerMode:
|
||||
"""Read only access to the read preference of this instance or session."""
|
||||
# Override this operation's read preference with the transaction's.
|
||||
if session:
|
||||
@ -16,11 +16,8 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
|
||||
from pymongo.synchronous.hello_compat import HelloCompat
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.helpers_shared import _SENSITIVE_COMMANDS
|
||||
|
||||
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
|
||||
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
|
||||
@ -20,3 +20,4 @@ from pymongo.synchronous.cursor import * # noqa: F403
|
||||
from pymongo.synchronous.cursor import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["Cursor", "CursorType", "RawBatchCursor"] # noqa: F405
|
||||
|
||||
@ -19,3 +19,4 @@ from pymongo.synchronous.database import * # noqa: F403
|
||||
from pymongo.synchronous.database import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["Database"] # noqa: F405
|
||||
|
||||
@ -19,3 +19,4 @@ from pymongo.synchronous.encryption import * # noqa: F403
|
||||
from pymongo.synchronous.encryption import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["Algorithm", "ClientEncryption", "QueryType", "RewrapManyDataKeyResult"] # noqa: F405
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,10 +12,261 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Re-import of synchronous EncryptionOptions API for compatibility."""
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.encryption_options import * # noqa: F403
|
||||
from pymongo.synchronous.encryption_options import __doc__ as original_doc
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
__doc__ = original_doc
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import] # noqa: F401
|
||||
|
||||
# Check for pymongocrypt>=1.10.
|
||||
from pymongocrypt import synchronous as _ # noqa: F401
|
||||
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
except ImportError:
|
||||
_HAVE_PYMONGOCRYPT = False
|
||||
from bson import int64
|
||||
from pymongo.common import validate_is_mapping
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
|
||||
"""Options to configure automatic client-side field level encryption."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None,
|
||||
schema_map: Optional[Mapping[str, Any]] = None,
|
||||
bypass_auto_encryption: bool = False,
|
||||
mongocryptd_uri: str = "mongodb://localhost:27020",
|
||||
mongocryptd_bypass_spawn: bool = False,
|
||||
mongocryptd_spawn_path: str = "mongocryptd",
|
||||
mongocryptd_spawn_args: Optional[list[str]] = None,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None,
|
||||
crypt_shared_lib_path: Optional[str] = None,
|
||||
crypt_shared_lib_required: bool = False,
|
||||
bypass_query_analysis: bool = False,
|
||||
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Options to configure automatic client-side field level encryption.
|
||||
|
||||
Automatic client-side field level encryption requires MongoDB >=4.2
|
||||
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
|
||||
supported for operations on a database or view and will result in
|
||||
error.
|
||||
|
||||
Although automatic encryption requires MongoDB >=4.2 enterprise or a
|
||||
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
|
||||
users. To configure automatic *decryption* without automatic
|
||||
*encryption* set ``bypass_auto_encryption=True``. Explicit
|
||||
encryption and explicit decryption is also supported for all users
|
||||
with the :class:`~pymongo.encryption.ClientEncryption` class.
|
||||
|
||||
See :ref:`automatic-client-side-encryption` for an example.
|
||||
|
||||
:param kms_providers: Map of KMS provider options. The `kms_providers`
|
||||
map values differ by provider:
|
||||
|
||||
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
|
||||
These are the AWS access key ID and AWS secret access key used
|
||||
to generate KMS messages. An optional "sessionToken" may be
|
||||
included to support temporary AWS credentials.
|
||||
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
|
||||
strings. Additionally, "identityPlatformEndpoint" may also be
|
||||
specified as a string (defaults to 'login.microsoftonline.com').
|
||||
These are the Azure Active Directory credentials used to
|
||||
generate Azure Key Vault messages.
|
||||
- `gcp`: Map with "email" as a string and "privateKey"
|
||||
as `bytes` or a base64 encoded string.
|
||||
Additionally, "endpoint" may also be specified as a string
|
||||
(defaults to 'oauth2.googleapis.com'). These are the
|
||||
credentials used to generate Google Cloud KMS messages.
|
||||
- `kmip`: Map with "endpoint" as a host with required port.
|
||||
For example: ``{"endpoint": "example.com:443"}``.
|
||||
- `local`: Map with "key" as `bytes` (96 bytes in length) or
|
||||
a base64 encoded string which decodes
|
||||
to 96 bytes. "key" is the master key used to encrypt/decrypt
|
||||
data keys. This key should be generated and stored as securely
|
||||
as possible.
|
||||
|
||||
KMS providers may be specified with an optional name suffix
|
||||
separated by a colon, for example "kmip:name" or "aws:name".
|
||||
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
|
||||
Named KMS providers enables more than one of each KMS provider type to be configured.
|
||||
For example, to configure multiple local KMS providers::
|
||||
|
||||
kms_providers = {
|
||||
"local": {"key": local_kek1}, # Unnamed KMS provider.
|
||||
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
|
||||
}
|
||||
|
||||
:param key_vault_namespace: The namespace for the key vault collection.
|
||||
The key vault collection contains all data keys used for encryption
|
||||
and decryption. Data keys are stored as documents in this MongoDB
|
||||
collection. Data keys are protected with encryption by a KMS
|
||||
provider.
|
||||
:param key_vault_client: By default, the key vault collection
|
||||
is assumed to reside in the same MongoDB cluster as the encrypted
|
||||
AsyncMongoClient/MongoClient. Use this option to route data key queries to a
|
||||
separate MongoDB cluster.
|
||||
:param schema_map: Map of collection namespace ("db.coll") to
|
||||
JSON Schema. By default, a collection's JSONSchema is periodically
|
||||
polled with the listCollections command. But a JSONSchema may be
|
||||
specified locally with the schemaMap option.
|
||||
|
||||
**Supplying a `schema_map` provides more security than relying on
|
||||
JSON Schemas obtained from the server. It protects against a
|
||||
malicious server advertising a false JSON Schema, which could trick
|
||||
the client into sending unencrypted data that should be
|
||||
encrypted.**
|
||||
|
||||
Schemas supplied in the schemaMap only apply to configuring
|
||||
automatic encryption for client side encryption. Other validation
|
||||
rules in the JSON schema will not be enforced by the driver and
|
||||
will result in an error.
|
||||
:param bypass_auto_encryption: If ``True``, automatic
|
||||
encryption will be disabled but automatic decryption will still be
|
||||
enabled. Defaults to ``False``.
|
||||
:param mongocryptd_uri: The MongoDB URI used to connect
|
||||
to the *local* mongocryptd process. Defaults to
|
||||
``'mongodb://localhost:27020'``.
|
||||
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
|
||||
AsyncMongoClient/MongoClient will not attempt to spawn the mongocryptd process.
|
||||
Defaults to ``False``.
|
||||
:param mongocryptd_spawn_path: Used for spawning the
|
||||
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
|
||||
mongocryptd from the system path.
|
||||
:param mongocryptd_spawn_args: A list of string arguments to
|
||||
use when spawning the mongocryptd process. Defaults to
|
||||
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
|
||||
the ``idleShutdownTimeoutSecs`` option then
|
||||
``'--idleShutdownTimeoutSecs=60'`` will be added.
|
||||
:param kms_tls_options: A map of KMS provider names to TLS
|
||||
options to use when creating secure connections to KMS providers.
|
||||
Accepts the same TLS options as
|
||||
:class:`pymongo.mongo_client.AsyncMongoClient` and :class:`pymongo.mongo_client.MongoClient`. For example, to
|
||||
override the system default CA file::
|
||||
|
||||
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
|
||||
|
||||
Or to supply a client certificate::
|
||||
|
||||
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
|
||||
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
|
||||
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
|
||||
unable to load the crypt_shared library.
|
||||
:param bypass_query_analysis: If ``True``, disable automatic analysis
|
||||
of outgoing commands. Set `bypass_query_analysis` to use explicit
|
||||
encryption on indexed fields without the MongoDB Enterprise Advanced
|
||||
licensed crypt_shared library.
|
||||
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
|
||||
that described the encrypted fields for Queryable Encryption. For example::
|
||||
|
||||
{
|
||||
"db.encryptedCollection": {
|
||||
"escCollection": "enxcol_.encryptedCollection.esc",
|
||||
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
|
||||
"fields": [
|
||||
{
|
||||
"path": "firstName",
|
||||
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
|
||||
"bsonType": "string",
|
||||
"queries": {"queryType": "equality"}
|
||||
},
|
||||
{
|
||||
"path": "ssn",
|
||||
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
|
||||
"bsonType": "string"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
|
||||
and `bypass_query_analysis` parameters.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
if not _HAVE_PYMONGOCRYPT:
|
||||
raise ConfigurationError(
|
||||
"client side encryption requires the pymongocrypt library: "
|
||||
"install a compatible version with: "
|
||||
"python -m pip install 'pymongo[encryption]'"
|
||||
)
|
||||
if encrypted_fields_map:
|
||||
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
|
||||
self._encrypted_fields_map = encrypted_fields_map
|
||||
self._bypass_query_analysis = bypass_query_analysis
|
||||
self._crypt_shared_lib_path = crypt_shared_lib_path
|
||||
self._crypt_shared_lib_required = crypt_shared_lib_required
|
||||
self._kms_providers = kms_providers
|
||||
self._key_vault_namespace = key_vault_namespace
|
||||
self._key_vault_client = key_vault_client
|
||||
self._schema_map = schema_map
|
||||
self._bypass_auto_encryption = bypass_auto_encryption
|
||||
self._mongocryptd_uri = mongocryptd_uri
|
||||
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
|
||||
self._mongocryptd_spawn_path = mongocryptd_spawn_path
|
||||
if mongocryptd_spawn_args is None:
|
||||
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
|
||||
self._mongocryptd_spawn_args = mongocryptd_spawn_args
|
||||
if not isinstance(self._mongocryptd_spawn_args, list):
|
||||
raise TypeError("mongocryptd_spawn_args must be a list")
|
||||
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
|
||||
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
|
||||
# Maps KMS provider name to a SSLContext.
|
||||
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
|
||||
self._bypass_query_analysis = bypass_query_analysis
|
||||
|
||||
|
||||
class RangeOpts:
|
||||
"""Options to configure encrypted queries using the range algorithm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sparsity: int,
|
||||
trim_factor: int,
|
||||
min: Optional[Any] = None,
|
||||
max: Optional[Any] = None,
|
||||
precision: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Options to configure encrypted queries using the range algorithm.
|
||||
|
||||
:param sparsity: An integer.
|
||||
:param trim_factor: An integer.
|
||||
:param min: A BSON scalar value corresponding to the type being queried.
|
||||
:param max: A BSON scalar value corresponding to the type being queried.
|
||||
:param precision: An integer, may only be set for double or decimal128 types.
|
||||
|
||||
.. versionadded:: 4.4
|
||||
"""
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.sparsity = sparsity
|
||||
self.trim_factor = trim_factor
|
||||
self.precision = precision
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
doc = {}
|
||||
for k, v in [
|
||||
("sparsity", int64.Int64(self.sparsity)),
|
||||
("trimFactor", self.trim_factor),
|
||||
("precision", self.precision),
|
||||
("min", self.min),
|
||||
("max", self.max),
|
||||
]:
|
||||
if v is not None:
|
||||
doc[k] = v
|
||||
return doc
|
||||
|
||||
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Un
|
||||
from bson.errors import InvalidDocument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.typings import _DocumentOut
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
|
||||
class PyMongoError(Exception):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,10 +12,212 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Re-import of synchronous EventLoggers API for compatibility."""
|
||||
|
||||
"""Example event logger classes.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
These loggers can be registered using :func:`register` or
|
||||
:class:`~pymongo.mongo_client.MongoClient`.
|
||||
|
||||
``monitoring.register(CommandLogger())``
|
||||
|
||||
or
|
||||
|
||||
``MongoClient(event_listeners=[CommandLogger()])``
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.event_loggers import * # noqa: F403
|
||||
from pymongo.synchronous.event_loggers import __doc__ as original_doc
|
||||
import logging
|
||||
|
||||
__doc__ = original_doc
|
||||
from pymongo import monitoring
|
||||
|
||||
|
||||
class CommandLogger(monitoring.CommandListener):
|
||||
"""A simple listener that logs command events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
|
||||
:class:`~pymongo.monitoring.CommandSucceededEvent` and
|
||||
:class:`~pymongo.monitoring.CommandFailedEvent` events and
|
||||
logs them at the `INFO` severity level using :mod:`logging`.
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def started(self, event: monitoring.CommandStartedEvent) -> None:
|
||||
logging.info(
|
||||
f"Command {event.command_name} with request id "
|
||||
f"{event.request_id} started on server "
|
||||
f"{event.connection_id}"
|
||||
)
|
||||
|
||||
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
|
||||
logging.info(
|
||||
f"Command {event.command_name} with request id "
|
||||
f"{event.request_id} on server {event.connection_id} "
|
||||
f"succeeded in {event.duration_micros} "
|
||||
"microseconds"
|
||||
)
|
||||
|
||||
def failed(self, event: monitoring.CommandFailedEvent) -> None:
|
||||
logging.info(
|
||||
f"Command {event.command_name} with request id "
|
||||
f"{event.request_id} on server {event.connection_id} "
|
||||
f"failed in {event.duration_micros} "
|
||||
"microseconds"
|
||||
)
|
||||
|
||||
|
||||
class ServerLogger(monitoring.ServerListener):
|
||||
"""A simple listener that logs server discovery events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
|
||||
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
|
||||
and :class:`~pymongo.monitoring.ServerClosedEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
|
||||
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
|
||||
|
||||
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
|
||||
previous_server_type = event.previous_description.server_type
|
||||
new_server_type = event.new_description.server_type
|
||||
if new_server_type != previous_server_type:
|
||||
# server_type_name was added in PyMongo 3.4
|
||||
logging.info(
|
||||
f"Server {event.server_address} changed type from "
|
||||
f"{event.previous_description.server_type_name} to "
|
||||
f"{event.new_description.server_type_name}"
|
||||
)
|
||||
|
||||
def closed(self, event: monitoring.ServerClosedEvent) -> None:
|
||||
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
|
||||
|
||||
|
||||
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
|
||||
"""A simple listener that logs server heartbeat events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
|
||||
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
|
||||
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
|
||||
logging.info(f"Heartbeat sent to server {event.connection_id}")
|
||||
|
||||
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
|
||||
# The reply.document attribute was added in PyMongo 3.4.
|
||||
logging.info(
|
||||
f"Heartbeat to server {event.connection_id} "
|
||||
"succeeded with reply "
|
||||
f"{event.reply.document}"
|
||||
)
|
||||
|
||||
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
|
||||
logging.warning(
|
||||
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
|
||||
)
|
||||
|
||||
|
||||
class TopologyLogger(monitoring.TopologyListener):
|
||||
"""A simple listener that logs server topology events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
|
||||
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
|
||||
and :class:`~pymongo.monitoring.TopologyClosedEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
|
||||
logging.info(f"Topology with id {event.topology_id} opened")
|
||||
|
||||
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
|
||||
logging.info(f"Topology description updated for topology id {event.topology_id}")
|
||||
previous_topology_type = event.previous_description.topology_type
|
||||
new_topology_type = event.new_description.topology_type
|
||||
if new_topology_type != previous_topology_type:
|
||||
# topology_type_name was added in PyMongo 3.4
|
||||
logging.info(
|
||||
f"Topology {event.topology_id} changed type from "
|
||||
f"{event.previous_description.topology_type_name} to "
|
||||
f"{event.new_description.topology_type_name}"
|
||||
)
|
||||
# The has_writable_server and has_readable_server methods
|
||||
# were added in PyMongo 3.4.
|
||||
if not event.new_description.has_writable_server():
|
||||
logging.warning("No writable servers available.")
|
||||
if not event.new_description.has_readable_server():
|
||||
logging.warning("No readable servers available.")
|
||||
|
||||
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
|
||||
logging.info(f"Topology with id {event.topology_id} closed")
|
||||
|
||||
|
||||
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
|
||||
"""A simple listener that logs server connection pool events.
|
||||
|
||||
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
|
||||
:class:`~pymongo.monitoring.PoolClearedEvent`,
|
||||
:class:`~pymongo.monitoring.PoolClosedEvent`,
|
||||
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
|
||||
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
|
||||
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
|
||||
events and logs them at the `INFO` severity level using :mod:`logging`.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
|
||||
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool created")
|
||||
|
||||
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool ready")
|
||||
|
||||
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool cleared")
|
||||
|
||||
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] pool closed")
|
||||
|
||||
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
|
||||
|
||||
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
|
||||
)
|
||||
|
||||
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] "
|
||||
f'connection closed, reason: "{event.reason}"'
|
||||
)
|
||||
|
||||
def connection_check_out_started(
|
||||
self, event: monitoring.ConnectionCheckOutStartedEvent
|
||||
) -> None:
|
||||
logging.info(f"[pool {event.address}] connection check out started")
|
||||
|
||||
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
|
||||
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
|
||||
|
||||
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
|
||||
)
|
||||
|
||||
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
|
||||
logging.info(
|
||||
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
|
||||
)
|
||||
|
||||
@ -21,12 +21,9 @@ import itertools
|
||||
from typing import Any, Generic, Mapping, Optional
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.asynchronous.typings import ClusterTime, _DocumentType
|
||||
from pymongo import common
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_IS_SYNC = False
|
||||
from pymongo.typings import ClusterTime, _DocumentType
|
||||
|
||||
|
||||
def _get_server_type(doc: Mapping[str, Any]) -> int:
|
||||
@ -57,6 +54,14 @@ def _get_server_type(doc: Mapping[str, Any]) -> int:
|
||||
return SERVER_TYPE.Standalone
|
||||
|
||||
|
||||
class HelloCompat:
|
||||
CMD = "hello"
|
||||
LEGACY_CMD = "ismaster"
|
||||
PRIMARY = "isWritablePrimary"
|
||||
LEGACY_PRIMARY = "ismaster"
|
||||
LEGACY_ERROR = "not master"
|
||||
|
||||
|
||||
class Hello(Generic[_DocumentType]):
|
||||
"""Parse a hello response from the server.
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Constants used by the driver that don't really fit elsewhere."""
|
||||
|
||||
# From the SDAM spec, the "node is shutting down" codes.
|
||||
from __future__ import annotations
|
||||
|
||||
_SHUTDOWN_CODES: frozenset = frozenset(
|
||||
[
|
||||
11600, # InterruptedAtShutdown
|
||||
91, # ShutdownInProgress
|
||||
]
|
||||
)
|
||||
# From the SDAM spec, the "not primary" error codes are combined with the
|
||||
# "node is recovering" error codes (of which the "node is shutting down"
|
||||
# errors are a subset).
|
||||
_NOT_PRIMARY_CODES: frozenset = (
|
||||
frozenset(
|
||||
[
|
||||
10058, # LegacyNotPrimary <=3.2 "not primary" error code
|
||||
10107, # NotWritablePrimary
|
||||
13435, # NotPrimaryNoSecondaryOk
|
||||
11602, # InterruptedDueToReplStateChange
|
||||
13436, # NotPrimaryOrSecondary
|
||||
189, # PrimarySteppedDown
|
||||
]
|
||||
)
|
||||
| _SHUTDOWN_CODES
|
||||
)
|
||||
# From the retryable writes spec.
|
||||
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
|
||||
[
|
||||
7, # HostNotFound
|
||||
6, # HostUnreachable
|
||||
89, # NetworkTimeout
|
||||
9001, # SocketException
|
||||
262, # ExceededTimeLimit
|
||||
134, # ReadConcernMajorityNotAvailableYet
|
||||
]
|
||||
)
|
||||
|
||||
# Server code raised when re-authentication is required
|
||||
_REAUTHENTICATION_REQUIRED_CODE: int = 391
|
||||
|
||||
# Server code raised when authentication fails.
|
||||
_AUTHENTICATION_FAILURE_CODE: int = 18
|
||||
|
||||
# Note - to avoid bugs from forgetting which if these is all lowercase and
|
||||
# which are camelCase, and at the same time avoid having to add a test for
|
||||
# every command, use all lowercase here and test against command_name.lower().
|
||||
_SENSITIVE_COMMANDS: set = {
|
||||
"authenticate",
|
||||
"saslstart",
|
||||
"saslcontinue",
|
||||
"getnonce",
|
||||
"createuser",
|
||||
"updateuser",
|
||||
"copydbgetnonce",
|
||||
"copydbsaslstart",
|
||||
"copydb",
|
||||
}
|
||||
327
pymongo/helpers_shared.py
Normal file
327
pymongo/helpers_shared.py
Normal file
@ -0,0 +1,327 @@
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Bits and pieces used by the driver that don't really fit elsewhere."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from collections import abc
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Container,
|
||||
Iterable,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.errors import (
|
||||
CursorNotFound,
|
||||
DuplicateKeyError,
|
||||
ExecutionTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
WriteConcernError,
|
||||
WriteError,
|
||||
WTimeoutError,
|
||||
_wtimeout_error,
|
||||
)
|
||||
from pymongo.hello import HelloCompat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.cursor_shared import _Hint
|
||||
from pymongo.operations import _IndexList
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
|
||||
# From the SDAM spec, the "node is shutting down" codes.
|
||||
|
||||
_SHUTDOWN_CODES: frozenset = frozenset(
|
||||
[
|
||||
11600, # InterruptedAtShutdown
|
||||
91, # ShutdownInProgress
|
||||
]
|
||||
)
|
||||
# From the SDAM spec, the "not primary" error codes are combined with the
|
||||
# "node is recovering" error codes (of which the "node is shutting down"
|
||||
# errors are a subset).
|
||||
_NOT_PRIMARY_CODES: frozenset = (
|
||||
frozenset(
|
||||
[
|
||||
10058, # LegacyNotPrimary <=3.2 "not primary" error code
|
||||
10107, # NotWritablePrimary
|
||||
13435, # NotPrimaryNoSecondaryOk
|
||||
11602, # InterruptedDueToReplStateChange
|
||||
13436, # NotPrimaryOrSecondary
|
||||
189, # PrimarySteppedDown
|
||||
]
|
||||
)
|
||||
| _SHUTDOWN_CODES
|
||||
)
|
||||
# From the retryable writes spec.
|
||||
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
|
||||
[
|
||||
7, # HostNotFound
|
||||
6, # HostUnreachable
|
||||
89, # NetworkTimeout
|
||||
9001, # SocketException
|
||||
262, # ExceededTimeLimit
|
||||
134, # ReadConcernMajorityNotAvailableYet
|
||||
]
|
||||
)
|
||||
|
||||
# Server code raised when re-authentication is required
|
||||
_REAUTHENTICATION_REQUIRED_CODE: int = 391
|
||||
|
||||
# Server code raised when authentication fails.
|
||||
_AUTHENTICATION_FAILURE_CODE: int = 18
|
||||
|
||||
# Note - to avoid bugs from forgetting which if these is all lowercase and
|
||||
# which are camelCase, and at the same time avoid having to add a test for
|
||||
# every command, use all lowercase here and test against command_name.lower().
|
||||
_SENSITIVE_COMMANDS: set = {
|
||||
"authenticate",
|
||||
"saslstart",
|
||||
"saslcontinue",
|
||||
"getnonce",
|
||||
"createuser",
|
||||
"updateuser",
|
||||
"copydbgetnonce",
|
||||
"copydbsaslstart",
|
||||
"copydb",
|
||||
}
|
||||
|
||||
|
||||
def _gen_index_name(keys: _IndexList) -> str:
|
||||
"""Generate an index name from the set of fields it is over."""
|
||||
return "_".join(["{}_{}".format(*item) for item in keys])
|
||||
|
||||
|
||||
def _index_list(
|
||||
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
|
||||
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
|
||||
"""Helper to generate a list of (key, direction) pairs.
|
||||
|
||||
Takes such a list, or a single key, or a single key and direction.
|
||||
"""
|
||||
if direction is not None:
|
||||
if not isinstance(key_or_list, str):
|
||||
raise TypeError("Expected a string and a direction")
|
||||
return [(key_or_list, direction)]
|
||||
else:
|
||||
if isinstance(key_or_list, str):
|
||||
return [(key_or_list, ASCENDING)]
|
||||
elif isinstance(key_or_list, abc.ItemsView):
|
||||
return list(key_or_list) # type: ignore[arg-type]
|
||||
elif isinstance(key_or_list, abc.Mapping):
|
||||
return list(key_or_list.items())
|
||||
elif not isinstance(key_or_list, (list, tuple)):
|
||||
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
||||
values: list[tuple[str, int]] = []
|
||||
for item in key_or_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
values.append(item)
|
||||
return values
|
||||
|
||||
|
||||
def _index_document(index_list: _IndexList) -> dict[str, Any]:
|
||||
"""Helper to generate an index specifying document.
|
||||
|
||||
Takes a list of (key, direction) pairs.
|
||||
"""
|
||||
if not isinstance(index_list, (list, tuple, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
|
||||
)
|
||||
if not len(index_list):
|
||||
raise ValueError("key_or_list must not be empty")
|
||||
|
||||
index: dict[str, Any] = {}
|
||||
|
||||
if isinstance(index_list, abc.Mapping):
|
||||
for key in index_list:
|
||||
value = index_list[key]
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
else:
|
||||
for item in index_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
key, value = item
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
return index
|
||||
|
||||
|
||||
def _validate_index_key_pair(key: Any, value: Any) -> None:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("first item in each key pair must be an instance of str")
|
||||
if not isinstance(value, (str, int, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"second item in each key pair must be 1, -1, "
|
||||
"'2d', or another valid MongoDB index specifier."
|
||||
)
|
||||
|
||||
|
||||
def _check_command_response(
|
||||
response: _DocumentOut,
|
||||
max_wire_version: Optional[int],
|
||||
allowable_errors: Optional[Container[Union[int, str]]] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
) -> None:
|
||||
"""Check the response to a command for errors."""
|
||||
if "ok" not in response:
|
||||
# Server didn't recognize our message as a command.
|
||||
raise OperationFailure(
|
||||
response.get("$err"), # type: ignore[arg-type]
|
||||
response.get("code"),
|
||||
response,
|
||||
max_wire_version,
|
||||
)
|
||||
|
||||
if parse_write_concern_error and "writeConcernError" in response:
|
||||
_error = response["writeConcernError"]
|
||||
_labels = response.get("errorLabels")
|
||||
if _labels:
|
||||
_error.update({"errorLabels": _labels})
|
||||
_raise_write_concern_error(_error)
|
||||
|
||||
if response["ok"]:
|
||||
return
|
||||
|
||||
details = response
|
||||
# Mongos returns the error details in a 'raw' object
|
||||
# for some errors.
|
||||
if "raw" in response:
|
||||
for shard in response["raw"].values():
|
||||
# Grab the first non-empty raw error from a shard.
|
||||
if shard.get("errmsg") and not shard.get("ok"):
|
||||
details = shard
|
||||
break
|
||||
|
||||
errmsg = details["errmsg"]
|
||||
code = details.get("code")
|
||||
|
||||
# For allowable errors, only check for error messages when the code is not
|
||||
# included.
|
||||
if allowable_errors:
|
||||
if code is not None:
|
||||
if code in allowable_errors:
|
||||
return
|
||||
elif errmsg in allowable_errors:
|
||||
return
|
||||
|
||||
# Server is "not primary" or "recovering"
|
||||
if code is not None:
|
||||
if code in _NOT_PRIMARY_CODES:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
|
||||
# Other errors
|
||||
# findAndModify with upsert can raise duplicate key error
|
||||
if code in (11000, 11001, 12582):
|
||||
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
|
||||
elif code == 50:
|
||||
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
|
||||
elif code == 43:
|
||||
raise CursorNotFound(errmsg, code, response, max_wire_version)
|
||||
|
||||
raise OperationFailure(errmsg, code, response, max_wire_version)
|
||||
|
||||
|
||||
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
|
||||
# If the last batch had multiple errors only report
|
||||
# the last error to emulate continue_on_error.
|
||||
error = write_errors[-1]
|
||||
if error.get("code") == 11000:
|
||||
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
|
||||
raise WriteError(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _raise_write_concern_error(error: Any) -> NoReturn:
|
||||
if _wtimeout_error(error):
|
||||
# Make sure we raise WTimeoutError
|
||||
raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
|
||||
raise WriteConcernError(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
|
||||
"""Return the writeConcernError or None."""
|
||||
wce = result.get("writeConcernError")
|
||||
if wce:
|
||||
# The server reports errorLabels at the top level but it's more
|
||||
# convenient to attach it to the writeConcernError doc itself.
|
||||
error_labels = result.get("errorLabels")
|
||||
if error_labels:
|
||||
# Copy to avoid changing the original document.
|
||||
wce = wce.copy()
|
||||
wce["errorLabels"] = error_labels
|
||||
return wce
|
||||
|
||||
|
||||
def _check_write_command_response(result: Mapping[str, Any]) -> None:
|
||||
"""Backward compatibility helper for write command error handling."""
|
||||
# Prefer write errors over write concern errors
|
||||
write_errors = result.get("writeErrors")
|
||||
if write_errors:
|
||||
_raise_last_write_error(write_errors)
|
||||
|
||||
wce = _get_wce_doc(result)
|
||||
if wce:
|
||||
_raise_write_concern_error(wce)
|
||||
|
||||
|
||||
def _fields_list_to_dict(
|
||||
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
|
||||
) -> Mapping[str, Any]:
|
||||
"""Takes a sequence of field names and returns a matching dictionary.
|
||||
|
||||
["a", "b"] becomes {"a": 1, "b": 1}
|
||||
|
||||
and
|
||||
|
||||
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
|
||||
"""
|
||||
if isinstance(fields, abc.Mapping):
|
||||
return fields
|
||||
|
||||
if isinstance(fields, (abc.Sequence, abc.Set)):
|
||||
if not all(isinstance(field, str) for field in fields):
|
||||
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
|
||||
return dict.fromkeys(fields, 1)
|
||||
|
||||
raise TypeError(f"{option_name} must be a mapping or list of key names")
|
||||
|
||||
|
||||
def _handle_exception() -> None:
|
||||
"""Print exceptions raised by subscribers to stderr."""
|
||||
# Heavily influenced by logging.Handler.handleError.
|
||||
|
||||
# See note here:
|
||||
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
|
||||
if sys.stderr:
|
||||
einfo = sys.exc_info()
|
||||
try:
|
||||
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
del einfo
|
||||
@ -21,9 +21,7 @@ from typing import Any
|
||||
|
||||
from bson import UuidRepresentation, json_util
|
||||
from bson.json_util import JSONOptions, _truncate_documents
|
||||
from pymongo.synchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
|
||||
_IS_SYNC = True
|
||||
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
|
||||
|
||||
class _CommandStatusMessage(str, enum.Enum):
|
||||
@ -34,9 +34,8 @@ from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.server_selectors import Selection
|
||||
from pymongo.server_selectors import Selection
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
|
||||
# 10 seconds to refresh secondaries' lastWriteDate values.
|
||||
File diff suppressed because it is too large
Load Diff
@ -19,3 +19,4 @@ from pymongo.synchronous.mongo_client import * # noqa: F403
|
||||
from pymongo.synchronous.mongo_client import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["MongoClient"] # noqa: F405
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,10 +12,613 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Re-import of synchronous Operations API for compatibility."""
|
||||
"""Operation class definitions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.operations import * # noqa: F403
|
||||
from pymongo.synchronous.operations import __doc__ as original_doc
|
||||
import enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
__doc__ = original_doc
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import helpers_shared
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import validate_is_mapping, validate_list
|
||||
from pymongo.helpers_shared import _gen_index_name, _index_document, _index_list
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _AgnosticBulk
|
||||
|
||||
|
||||
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
|
||||
_IndexList = Union[
|
||||
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
|
||||
]
|
||||
_IndexKeyHint = Union[str, _IndexList]
|
||||
|
||||
|
||||
class _Op(str, enum.Enum):
|
||||
ABORT = "abortTransaction"
|
||||
AGGREGATE = "aggregate"
|
||||
COMMIT = "commitTransaction"
|
||||
COUNT = "count"
|
||||
CREATE = "create"
|
||||
CREATE_INDEXES = "createIndexes"
|
||||
CREATE_SEARCH_INDEXES = "createSearchIndexes"
|
||||
DELETE = "delete"
|
||||
DISTINCT = "distinct"
|
||||
DROP = "drop"
|
||||
DROP_DATABASE = "dropDatabase"
|
||||
DROP_INDEXES = "dropIndexes"
|
||||
DROP_SEARCH_INDEXES = "dropSearchIndexes"
|
||||
END_SESSIONS = "endSessions"
|
||||
FIND_AND_MODIFY = "findAndModify"
|
||||
FIND = "find"
|
||||
INSERT = "insert"
|
||||
LIST_COLLECTIONS = "listCollections"
|
||||
LIST_INDEXES = "listIndexes"
|
||||
LIST_SEARCH_INDEX = "listSearchIndexes"
|
||||
LIST_DATABASES = "listDatabases"
|
||||
UPDATE = "update"
|
||||
UPDATE_INDEX = "updateIndex"
|
||||
UPDATE_SEARCH_INDEX = "updateSearchIndex"
|
||||
RENAME = "rename"
|
||||
GETMORE = "getMore"
|
||||
KILL_CURSORS = "killCursors"
|
||||
TEST = "testOperation"
|
||||
|
||||
|
||||
class InsertOne(Generic[_DocumentType]):
|
||||
"""Represents an insert_one operation."""
|
||||
|
||||
__slots__ = ("_doc",)
|
||||
|
||||
def __init__(self, document: _DocumentType) -> None:
|
||||
"""Create an InsertOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
|
||||
:param document: The document to insert. If the document is missing an
|
||||
_id field one will be added.
|
||||
"""
|
||||
self._doc = document
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
|
||||
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
|
||||
bulkobj.add_insert(self._doc) # type: ignore[arg-type]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"InsertOne({self._doc!r})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return other._doc == self._doc
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class DeleteOne:
|
||||
"""Represents a delete_one operation."""
|
||||
|
||||
__slots__ = ("_filter", "_collation", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create a DeleteOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the document to delete.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hint`` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
self._filter = filter
|
||||
self._collation = collation
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
|
||||
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
|
||||
bulkobj.add_delete(
|
||||
self._filter,
|
||||
1,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (other._filter, other._collation, other._hint) == (
|
||||
self._filter,
|
||||
self._collation,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class DeleteMany:
|
||||
"""Represents a delete_many operation."""
|
||||
|
||||
__slots__ = ("_filter", "_collation", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create a DeleteMany instance.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the documents to delete.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.4 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hint`` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
self._filter = filter
|
||||
self._collation = collation
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
|
||||
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
|
||||
bulkobj.add_delete(
|
||||
self._filter,
|
||||
0,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (other._filter, other._collation, other._hint) == (
|
||||
self._filter,
|
||||
self._collation,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class ReplaceOne(Generic[_DocumentType]):
|
||||
"""Represents a replace_one operation."""
|
||||
|
||||
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
replacement: Union[_DocumentType, RawBSONDocument],
|
||||
upsert: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create a ReplaceOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the document to replace.
|
||||
:param replacement: The new document.
|
||||
:param upsert: If ``True``, perform an insert if no documents
|
||||
match the filter.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hint`` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the ``collation`` option.
|
||||
"""
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
self._filter = filter
|
||||
self._doc = replacement
|
||||
self._upsert = upsert
|
||||
self._collation = collation
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
|
||||
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
|
||||
bulkobj.add_replace(
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (
|
||||
other._filter,
|
||||
other._doc,
|
||||
other._upsert,
|
||||
other._collation,
|
||||
other._hint,
|
||||
) == (
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
other._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
|
||||
self.__class__.__name__,
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._hint,
|
||||
)
|
||||
|
||||
|
||||
class _UpdateOp:
|
||||
"""Private base class for update operations."""
|
||||
|
||||
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
doc: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool,
|
||||
collation: Optional[_CollationIn],
|
||||
array_filters: Optional[list[Mapping[str, Any]]],
|
||||
hint: Optional[_IndexKeyHint],
|
||||
):
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
if array_filters is not None:
|
||||
validate_list("array_filters", array_filters)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
|
||||
self._filter = filter
|
||||
self._doc = doc
|
||||
self._upsert = upsert
|
||||
self._collation = collation
|
||||
self._array_filters = array_filters
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, type(self)):
|
||||
return (
|
||||
other._filter,
|
||||
other._doc,
|
||||
other._upsert,
|
||||
other._collation,
|
||||
other._array_filters,
|
||||
other._hint,
|
||||
) == (
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._array_filters,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
|
||||
self.__class__.__name__,
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._array_filters,
|
||||
self._hint,
|
||||
)
|
||||
|
||||
|
||||
class UpdateOne(_UpdateOp):
|
||||
"""Represents an update_one operation."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[list[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Represents an update_one operation.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the document to update.
|
||||
:param update: The modifications to apply.
|
||||
:param upsert: If ``True``, perform an insert if no documents
|
||||
match the filter.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param array_filters: A list of filters specifying which
|
||||
array elements an update should apply.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the `hint` option.
|
||||
.. versionchanged:: 3.9
|
||||
Added the ability to accept a pipeline as the `update`.
|
||||
.. versionchanged:: 3.6
|
||||
Added the `array_filters` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
super().__init__(filter, update, upsert, collation, array_filters, hint)
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
|
||||
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
|
||||
bulkobj.add_update(
|
||||
self._filter,
|
||||
self._doc,
|
||||
False,
|
||||
self._upsert,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
array_filters=self._array_filters,
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
|
||||
class UpdateMany(_UpdateOp):
|
||||
"""Represents an update_many operation."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[list[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
) -> None:
|
||||
"""Create an UpdateMany instance.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
|
||||
:param filter: A query that matches the documents to update.
|
||||
:param update: The modifications to apply.
|
||||
:param upsert: If ``True``, perform an insert if no documents
|
||||
match the filter.
|
||||
:param collation: An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
:param array_filters: A list of filters specifying which
|
||||
array elements an update should apply.
|
||||
:param hint: An index to use to support the query
|
||||
predicate specified either by its string name, or in the same
|
||||
format as passed to
|
||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g.
|
||||
``[('field', ASCENDING)]``). This option is only supported on
|
||||
MongoDB 4.2 and above.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the `hint` option.
|
||||
.. versionchanged:: 3.9
|
||||
Added the ability to accept a pipeline as the `update`.
|
||||
.. versionchanged:: 3.6
|
||||
Added the `array_filters` option.
|
||||
.. versionchanged:: 3.5
|
||||
Added the `collation` option.
|
||||
"""
|
||||
super().__init__(filter, update, upsert, collation, array_filters, hint)
|
||||
|
||||
def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None:
|
||||
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`."""
|
||||
bulkobj.add_update(
|
||||
self._filter,
|
||||
self._doc,
|
||||
True,
|
||||
self._upsert,
|
||||
collation=validate_collation_or_none(self._collation),
|
||||
array_filters=self._array_filters,
|
||||
hint=self._hint,
|
||||
)
|
||||
|
||||
|
||||
class IndexModel:
|
||||
"""Represents an index to create."""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
|
||||
"""Create an Index instance.
|
||||
|
||||
For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_indexes` and :meth:`~pymongo.collection.Collection.create_indexes`.
|
||||
|
||||
Takes either a single key or a list containing (key, direction) pairs
|
||||
or keys. If no direction is given, :data:`~pymongo.ASCENDING` will
|
||||
be assumed.
|
||||
The key(s) must be an instance of :class:`str`, and the direction(s) must
|
||||
be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
|
||||
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
|
||||
:data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
|
||||
|
||||
Valid options include, but are not limited to:
|
||||
|
||||
- `name`: custom name to use for this index - if none is
|
||||
given, a name will be generated.
|
||||
- `unique`: if ``True``, creates a uniqueness constraint on the index.
|
||||
- `background`: if ``True``, this index should be created in the
|
||||
background.
|
||||
- `sparse`: if ``True``, omit from the index any documents that lack
|
||||
the indexed field.
|
||||
- `bucketSize`: for use with geoHaystack indexes.
|
||||
Number of documents to group together within a certain proximity
|
||||
to a given longitude and latitude.
|
||||
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
|
||||
index.
|
||||
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
|
||||
index.
|
||||
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
|
||||
collection. MongoDB will automatically delete documents from
|
||||
this collection after <int> seconds. The indexed field must
|
||||
be a UTC datetime or the data will not expire.
|
||||
- `partialFilterExpression`: A document that specifies a filter for
|
||||
a partial index.
|
||||
- `collation`: An instance of :class:`~pymongo.collation.Collation`
|
||||
that specifies the collation to use.
|
||||
- `wildcardProjection`: Allows users to include or exclude specific
|
||||
field paths from a `wildcard index`_ using the { "$**" : 1} key
|
||||
pattern. Requires MongoDB >= 4.2.
|
||||
- `hidden`: if ``True``, this index will be hidden from the query
|
||||
planner and will not be evaluated as part of query plan
|
||||
selection. Requires MongoDB >= 4.4.
|
||||
|
||||
See the MongoDB documentation for a full list of supported options by
|
||||
server version.
|
||||
|
||||
:param keys: a single key or a list containing (key, direction) pairs
|
||||
or keys specifying the index to create.
|
||||
:param kwargs: any additional index creation
|
||||
options (see the above list) should be passed as keyword
|
||||
arguments.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added the ``hidden`` option.
|
||||
.. versionchanged:: 3.2
|
||||
Added the ``partialFilterExpression`` option to support partial
|
||||
indexes.
|
||||
|
||||
.. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
|
||||
"""
|
||||
keys = _index_list(keys)
|
||||
if kwargs.get("name") is None:
|
||||
kwargs["name"] = _gen_index_name(keys)
|
||||
kwargs["key"] = _index_document(keys)
|
||||
collation = validate_collation_or_none(kwargs.pop("collation", None))
|
||||
self.__document = kwargs
|
||||
if collation is not None:
|
||||
self.__document["collation"] = collation
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""An index document suitable for passing to the createIndexes
|
||||
command.
|
||||
"""
|
||||
return self.__document
|
||||
|
||||
|
||||
class SearchIndexModel:
|
||||
"""Represents a search index to create."""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
definition: Mapping[str, Any],
|
||||
name: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a Search Index instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`.
|
||||
|
||||
:param definition: The definition for this index.
|
||||
:param name: The name for this index, if present.
|
||||
:param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
|
||||
:param kwargs: Keyword arguments supplying any additional options.
|
||||
|
||||
.. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
|
||||
.. versionadded:: 4.5
|
||||
.. versionchanged:: 4.7
|
||||
Added the type and kwargs arguments.
|
||||
"""
|
||||
self.__document: dict[str, Any] = {}
|
||||
if name is not None:
|
||||
self.__document["name"] = name
|
||||
self.__document["definition"] = definition
|
||||
if type is not None:
|
||||
self.__document["type"] = type
|
||||
self.__document.update(kwargs)
|
||||
|
||||
@property
|
||||
def document(self) -> Mapping[str, Any]:
|
||||
"""The document for this index."""
|
||||
return self.__document
|
||||
|
||||
@ -19,3 +19,4 @@ from pymongo.synchronous.pool import * # noqa: F403
|
||||
from pymongo.synchronous.pool import __doc__ as original_doc
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = ["PoolOptions"] # noqa: F405
|
||||
|
||||
507
pymongo/pool_options.py
Normal file
507
pymongo/pool_options.py
Normal file
@ -0,0 +1,507 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""AsyncConnection pool options for AsyncMongoClient/MongoClient."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, MutableMapping, Optional
|
||||
|
||||
import bson
|
||||
from pymongo import __version__
|
||||
from pymongo.common import (
|
||||
MAX_CONNECTING,
|
||||
MAX_IDLE_TIME_SEC,
|
||||
MAX_POOL_SIZE,
|
||||
MIN_POOL_SIZE,
|
||||
WAIT_QUEUE_TIMEOUT,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
from pymongo.compression_support import CompressionSettings
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.server_api import ServerApi
|
||||
|
||||
|
||||
_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}}
|
||||
|
||||
if sys.platform.startswith("linux"):
|
||||
# platform.linux_distribution was deprecated in Python 3.5
|
||||
# and removed in Python 3.8. Starting in Python 3.5 it
|
||||
# raises DeprecationWarning
|
||||
# DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5
|
||||
_name = platform.system()
|
||||
_METADATA["os"] = {
|
||||
"type": _name,
|
||||
"name": _name,
|
||||
"architecture": platform.machine(),
|
||||
# Kernel version (e.g. 4.4.0-17-generic).
|
||||
"version": platform.release(),
|
||||
}
|
||||
elif sys.platform == "darwin":
|
||||
_METADATA["os"] = {
|
||||
"type": platform.system(),
|
||||
"name": platform.system(),
|
||||
"architecture": platform.machine(),
|
||||
# (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin
|
||||
# kernel version.
|
||||
"version": platform.mac_ver()[0],
|
||||
}
|
||||
elif sys.platform == "win32":
|
||||
_METADATA["os"] = {
|
||||
"type": platform.system(),
|
||||
# "Windows XP", "Windows 7", "Windows 10", etc.
|
||||
"name": " ".join((platform.system(), platform.release())),
|
||||
"architecture": platform.machine(),
|
||||
# Windows patch level (e.g. 5.1.2600-SP3)
|
||||
"version": "-".join(platform.win32_ver()[1:3]),
|
||||
}
|
||||
elif sys.platform.startswith("java"):
|
||||
_name, _ver, _arch = platform.java_ver()[-1]
|
||||
_METADATA["os"] = {
|
||||
# Linux, Windows 7, Mac OS X, etc.
|
||||
"type": _name,
|
||||
"name": _name,
|
||||
# x86, x86_64, AMD64, etc.
|
||||
"architecture": _arch,
|
||||
# Linux kernel version, OSX version, etc.
|
||||
"version": _ver,
|
||||
}
|
||||
else:
|
||||
# Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11)
|
||||
_aliased = platform.system_alias(platform.system(), platform.release(), platform.version())
|
||||
_METADATA["os"] = {
|
||||
"type": platform.system(),
|
||||
"name": " ".join([part for part in _aliased[:2] if part]),
|
||||
"architecture": platform.machine(),
|
||||
"version": _aliased[2],
|
||||
}
|
||||
|
||||
if platform.python_implementation().startswith("PyPy"):
|
||||
_METADATA["platform"] = " ".join(
|
||||
(
|
||||
platform.python_implementation(),
|
||||
".".join(map(str, sys.pypy_version_info)), # type: ignore
|
||||
"(Python %s)" % ".".join(map(str, sys.version_info)),
|
||||
)
|
||||
)
|
||||
elif sys.platform.startswith("java"):
|
||||
_METADATA["platform"] = " ".join(
|
||||
(
|
||||
platform.python_implementation(),
|
||||
".".join(map(str, sys.version_info)),
|
||||
"(%s)" % " ".join((platform.system(), platform.release())),
|
||||
)
|
||||
)
|
||||
else:
|
||||
_METADATA["platform"] = " ".join(
|
||||
(platform.python_implementation(), ".".join(map(str, sys.version_info)))
|
||||
)
|
||||
|
||||
DOCKER_ENV_PATH = "/.dockerenv"
|
||||
ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST"
|
||||
|
||||
RUNTIME_NAME_DOCKER = "docker"
|
||||
ORCHESTRATOR_NAME_K8S = "kubernetes"
|
||||
|
||||
|
||||
def get_container_env_info() -> dict[str, str]:
|
||||
"""Returns the runtime and orchestrator of a container.
|
||||
If neither value is present, the metadata client.env.container field will be omitted."""
|
||||
container = {}
|
||||
|
||||
if Path(DOCKER_ENV_PATH).exists():
|
||||
container["runtime"] = RUNTIME_NAME_DOCKER
|
||||
if os.getenv(ENV_VAR_K8S):
|
||||
container["orchestrator"] = ORCHESTRATOR_NAME_K8S
|
||||
|
||||
return container
|
||||
|
||||
|
||||
def _is_lambda() -> bool:
|
||||
if os.getenv("AWS_LAMBDA_RUNTIME_API"):
|
||||
return True
|
||||
env = os.getenv("AWS_EXECUTION_ENV")
|
||||
if env:
|
||||
return env.startswith("AWS_Lambda_")
|
||||
return False
|
||||
|
||||
|
||||
def _is_azure_func() -> bool:
|
||||
return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME"))
|
||||
|
||||
|
||||
def _is_gcp_func() -> bool:
|
||||
return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME"))
|
||||
|
||||
|
||||
def _is_vercel() -> bool:
|
||||
return bool(os.getenv("VERCEL"))
|
||||
|
||||
|
||||
def _is_faas() -> bool:
|
||||
return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel()
|
||||
|
||||
|
||||
def _getenv_int(key: str) -> Optional[int]:
|
||||
"""Like os.getenv but returns an int, or None if the value is missing/malformed."""
|
||||
val = os.getenv(key)
|
||||
if not val:
|
||||
return None
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _metadata_env() -> dict[str, Any]:
|
||||
env: dict[str, Any] = {}
|
||||
container = get_container_env_info()
|
||||
if container:
|
||||
env["container"] = container
|
||||
# Skip if multiple (or no) envs are matched.
|
||||
if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1:
|
||||
return env
|
||||
if _is_lambda():
|
||||
env["name"] = "aws.lambda"
|
||||
region = os.getenv("AWS_REGION")
|
||||
if region:
|
||||
env["region"] = region
|
||||
memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")
|
||||
if memory_mb is not None:
|
||||
env["memory_mb"] = memory_mb
|
||||
elif _is_azure_func():
|
||||
env["name"] = "azure.func"
|
||||
elif _is_gcp_func():
|
||||
env["name"] = "gcp.func"
|
||||
region = os.getenv("FUNCTION_REGION")
|
||||
if region:
|
||||
env["region"] = region
|
||||
memory_mb = _getenv_int("FUNCTION_MEMORY_MB")
|
||||
if memory_mb is not None:
|
||||
env["memory_mb"] = memory_mb
|
||||
timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC")
|
||||
if timeout_sec is not None:
|
||||
env["timeout_sec"] = timeout_sec
|
||||
elif _is_vercel():
|
||||
env["name"] = "vercel"
|
||||
region = os.getenv("VERCEL_REGION")
|
||||
if region:
|
||||
env["region"] = region
|
||||
return env
|
||||
|
||||
|
||||
_MAX_METADATA_SIZE = 512
|
||||
|
||||
|
||||
# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations
|
||||
def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None:
|
||||
"""Perform metadata truncation."""
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 1. Omit fields from env except env.name.
|
||||
env_name = metadata.get("env", {}).get("name")
|
||||
if env_name:
|
||||
metadata["env"] = {"name": env_name}
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 2. Omit fields from os except os.type.
|
||||
os_type = metadata.get("os", {}).get("type")
|
||||
if os_type:
|
||||
metadata["os"] = {"type": os_type}
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 3. Omit the env document entirely.
|
||||
metadata.pop("env", None)
|
||||
encoded_size = len(bson.encode(metadata))
|
||||
if encoded_size <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 4. Truncate platform.
|
||||
overflow = encoded_size - _MAX_METADATA_SIZE
|
||||
plat = metadata.get("platform", "")
|
||||
if plat:
|
||||
plat = plat[:-overflow]
|
||||
if plat:
|
||||
metadata["platform"] = plat
|
||||
else:
|
||||
metadata.pop("platform", None)
|
||||
encoded_size = len(bson.encode(metadata))
|
||||
if encoded_size <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# 5. Truncate driver info.
|
||||
overflow = encoded_size - _MAX_METADATA_SIZE
|
||||
driver = metadata.get("driver", {})
|
||||
if driver:
|
||||
# Truncate driver version.
|
||||
driver_version = driver.get("version")[:-overflow]
|
||||
if len(driver_version) >= len(_METADATA["driver"]["version"]):
|
||||
metadata["driver"]["version"] = driver_version
|
||||
else:
|
||||
metadata["driver"]["version"] = _METADATA["driver"]["version"]
|
||||
encoded_size = len(bson.encode(metadata))
|
||||
if encoded_size <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
# Truncate driver name.
|
||||
overflow = encoded_size - _MAX_METADATA_SIZE
|
||||
driver_name = driver.get("name")[:-overflow]
|
||||
if len(driver_name) >= len(_METADATA["driver"]["name"]):
|
||||
metadata["driver"]["name"] = driver_name
|
||||
else:
|
||||
metadata["driver"]["name"] = _METADATA["driver"]["name"]
|
||||
|
||||
|
||||
# If the first getaddrinfo call of this interpreter's life is on a thread,
|
||||
# while the main thread holds the import lock, getaddrinfo deadlocks trying
|
||||
# to import the IDNA codec. Import it here, where presumably we're on the
|
||||
# main thread, to avoid the deadlock. See PYTHON-607.
|
||||
"foo".encode("idna")
|
||||
|
||||
|
||||
class PoolOptions:
|
||||
"""Read only connection pool options for an AsyncMongoClient/MongoClient.
|
||||
|
||||
Should not be instantiated directly by application developers. Access
|
||||
a client's pool options via
|
||||
:attr:`~pymongo.client_options.ClientOptions.pool_options` instead::
|
||||
|
||||
pool_opts = client.options.pool_options
|
||||
pool_opts.max_pool_size
|
||||
pool_opts.min_pool_size
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"__max_pool_size",
|
||||
"__min_pool_size",
|
||||
"__max_idle_time_seconds",
|
||||
"__connect_timeout",
|
||||
"__socket_timeout",
|
||||
"__wait_queue_timeout",
|
||||
"__ssl_context",
|
||||
"__tls_allow_invalid_hostnames",
|
||||
"__event_listeners",
|
||||
"__appname",
|
||||
"__driver",
|
||||
"__metadata",
|
||||
"__compression_settings",
|
||||
"__max_connecting",
|
||||
"__pause_enabled",
|
||||
"__server_api",
|
||||
"__load_balanced",
|
||||
"__credentials",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_pool_size: int = MAX_POOL_SIZE,
|
||||
min_pool_size: int = MIN_POOL_SIZE,
|
||||
max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC,
|
||||
connect_timeout: Optional[float] = None,
|
||||
socket_timeout: Optional[float] = None,
|
||||
wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT,
|
||||
ssl_context: Optional[SSLContext] = None,
|
||||
tls_allow_invalid_hostnames: bool = False,
|
||||
event_listeners: Optional[_EventListeners] = None,
|
||||
appname: Optional[str] = None,
|
||||
driver: Optional[DriverInfo] = None,
|
||||
compression_settings: Optional[CompressionSettings] = None,
|
||||
max_connecting: int = MAX_CONNECTING,
|
||||
pause_enabled: bool = True,
|
||||
server_api: Optional[ServerApi] = None,
|
||||
load_balanced: Optional[bool] = None,
|
||||
credentials: Optional[MongoCredential] = None,
|
||||
):
|
||||
self.__max_pool_size = max_pool_size
|
||||
self.__min_pool_size = min_pool_size
|
||||
self.__max_idle_time_seconds = max_idle_time_seconds
|
||||
self.__connect_timeout = connect_timeout
|
||||
self.__socket_timeout = socket_timeout
|
||||
self.__wait_queue_timeout = wait_queue_timeout
|
||||
self.__ssl_context = ssl_context
|
||||
self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames
|
||||
self.__event_listeners = event_listeners
|
||||
self.__appname = appname
|
||||
self.__driver = driver
|
||||
self.__compression_settings = compression_settings
|
||||
self.__max_connecting = max_connecting
|
||||
self.__pause_enabled = pause_enabled
|
||||
self.__server_api = server_api
|
||||
self.__load_balanced = load_balanced
|
||||
self.__credentials = credentials
|
||||
self.__metadata = copy.deepcopy(_METADATA)
|
||||
if appname:
|
||||
self.__metadata["application"] = {"name": appname}
|
||||
|
||||
# Combine the "driver" AsyncMongoClient option with PyMongo's info, like:
|
||||
# {
|
||||
# 'driver': {
|
||||
# 'name': 'PyMongo|MyDriver',
|
||||
# 'version': '4.2.0|1.2.3',
|
||||
# },
|
||||
# 'platform': 'CPython 3.8.0|MyPlatform'
|
||||
# }
|
||||
if driver:
|
||||
if driver.name:
|
||||
self.__metadata["driver"]["name"] = "{}|{}".format(
|
||||
_METADATA["driver"]["name"],
|
||||
driver.name,
|
||||
)
|
||||
if driver.version:
|
||||
self.__metadata["driver"]["version"] = "{}|{}".format(
|
||||
_METADATA["driver"]["version"],
|
||||
driver.version,
|
||||
)
|
||||
if driver.platform:
|
||||
self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform)
|
||||
|
||||
env = _metadata_env()
|
||||
if env:
|
||||
self.__metadata["env"] = env
|
||||
|
||||
_truncate_metadata(self.__metadata)
|
||||
|
||||
@property
|
||||
def _credentials(self) -> Optional[MongoCredential]:
|
||||
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
|
||||
return self.__credentials
|
||||
|
||||
@property
|
||||
def non_default_options(self) -> dict[str, Any]:
|
||||
"""The non-default options this pool was created with.
|
||||
|
||||
Added for CMAP's :class:`PoolCreatedEvent`.
|
||||
"""
|
||||
opts = {}
|
||||
if self.__max_pool_size != MAX_POOL_SIZE:
|
||||
opts["maxPoolSize"] = self.__max_pool_size
|
||||
if self.__min_pool_size != MIN_POOL_SIZE:
|
||||
opts["minPoolSize"] = self.__min_pool_size
|
||||
if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC:
|
||||
assert self.__max_idle_time_seconds is not None
|
||||
opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000
|
||||
if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT:
|
||||
assert self.__wait_queue_timeout is not None
|
||||
opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000
|
||||
if self.__max_connecting != MAX_CONNECTING:
|
||||
opts["maxConnecting"] = self.__max_connecting
|
||||
return opts
|
||||
|
||||
@property
|
||||
def max_pool_size(self) -> float:
|
||||
"""The maximum allowable number of concurrent connections to each
|
||||
connected server. Requests to a server will block if there are
|
||||
`maxPoolSize` outstanding connections to the requested server.
|
||||
Defaults to 100. Cannot be 0.
|
||||
|
||||
When a server's pool has reached `max_pool_size`, operations for that
|
||||
server block waiting for a socket to be returned to the pool. If
|
||||
``waitQueueTimeoutMS`` is set, a blocked operation will raise
|
||||
:exc:`~pymongo.errors.ConnectionFailure` after a timeout.
|
||||
By default ``waitQueueTimeoutMS`` is not set.
|
||||
"""
|
||||
return self.__max_pool_size
|
||||
|
||||
@property
|
||||
def min_pool_size(self) -> int:
|
||||
"""The minimum required number of concurrent connections that the pool
|
||||
will maintain to each connected server. Default is 0.
|
||||
"""
|
||||
return self.__min_pool_size
|
||||
|
||||
@property
|
||||
def max_connecting(self) -> int:
|
||||
"""The maximum number of concurrent connection creation attempts per
|
||||
pool. Defaults to 2.
|
||||
"""
|
||||
return self.__max_connecting
|
||||
|
||||
@property
|
||||
def pause_enabled(self) -> bool:
|
||||
return self.__pause_enabled
|
||||
|
||||
@property
|
||||
def max_idle_time_seconds(self) -> Optional[int]:
|
||||
"""The maximum number of seconds that a connection can remain
|
||||
idle in the pool before being removed and replaced. Defaults to
|
||||
`None` (no limit).
|
||||
"""
|
||||
return self.__max_idle_time_seconds
|
||||
|
||||
@property
|
||||
def connect_timeout(self) -> Optional[float]:
|
||||
"""How long a connection can take to be opened before timing out."""
|
||||
return self.__connect_timeout
|
||||
|
||||
@property
|
||||
def socket_timeout(self) -> Optional[float]:
|
||||
"""How long a send or receive on a socket can take before timing out."""
|
||||
return self.__socket_timeout
|
||||
|
||||
@property
|
||||
def wait_queue_timeout(self) -> Optional[int]:
|
||||
"""How long a thread will wait for a socket from the pool if the pool
|
||||
has no free sockets.
|
||||
"""
|
||||
return self.__wait_queue_timeout
|
||||
|
||||
@property
|
||||
def _ssl_context(self) -> Optional[SSLContext]:
|
||||
"""An SSLContext instance or None."""
|
||||
return self.__ssl_context
|
||||
|
||||
@property
|
||||
def tls_allow_invalid_hostnames(self) -> bool:
|
||||
"""If True skip ssl.match_hostname."""
|
||||
return self.__tls_allow_invalid_hostnames
|
||||
|
||||
@property
|
||||
def _event_listeners(self) -> Optional[_EventListeners]:
|
||||
"""An instance of pymongo.monitoring._EventListeners."""
|
||||
return self.__event_listeners
|
||||
|
||||
@property
|
||||
def appname(self) -> Optional[str]:
|
||||
"""The application name, for sending with hello in server handshake."""
|
||||
return self.__appname
|
||||
|
||||
@property
|
||||
def driver(self) -> Optional[DriverInfo]:
|
||||
"""Driver name and version, for sending with hello in handshake."""
|
||||
return self.__driver
|
||||
|
||||
@property
|
||||
def _compression_settings(self) -> Optional[CompressionSettings]:
|
||||
return self.__compression_settings
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
"""A dict of metadata about the application, driver, os, and platform."""
|
||||
return self.__metadata.copy()
|
||||
|
||||
@property
|
||||
def server_api(self) -> Optional[ServerApi]:
|
||||
"""A pymongo.server_api.ServerApi or None."""
|
||||
return self.__server_api
|
||||
|
||||
@property
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if this Pool is configured in load balanced mode."""
|
||||
return self.__load_balanced
|
||||
@ -1,6 +1,6 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2012-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# Licensed under the Apache License, Version 2.0 (the "License",
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
@ -12,10 +12,612 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Re-import of synchronous ReadPreferences API for compatibility."""
|
||||
"""Utilities for choosing which member of a replica set to read from."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.read_preferences import * # noqa: F403
|
||||
from pymongo.synchronous.read_preferences import __doc__ as original_doc
|
||||
from collections import abc
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
|
||||
|
||||
__doc__ = original_doc
|
||||
from pymongo import max_staleness_selectors
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_selectors import (
|
||||
member_with_tags_server_selector,
|
||||
secondary_with_tags_server_selector,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.server_selectors import Selection
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
|
||||
|
||||
_PRIMARY = 0
|
||||
_PRIMARY_PREFERRED = 1
|
||||
_SECONDARY = 2
|
||||
_SECONDARY_PREFERRED = 3
|
||||
_NEAREST = 4
|
||||
|
||||
|
||||
_MONGOS_MODES = (
|
||||
"primary",
|
||||
"primaryPreferred",
|
||||
"secondary",
|
||||
"secondaryPreferred",
|
||||
"nearest",
|
||||
)
|
||||
|
||||
_Hedge = Mapping[str, Any]
|
||||
_TagSets = Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
|
||||
"""Validate tag sets for a MongoClient."""
|
||||
if tag_sets is None:
|
||||
return tag_sets
|
||||
|
||||
if not isinstance(tag_sets, (list, tuple)):
|
||||
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
|
||||
if len(tag_sets) == 0:
|
||||
raise ValueError(
|
||||
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
|
||||
)
|
||||
|
||||
for tags in tag_sets:
|
||||
if not isinstance(tags, abc.Mapping):
|
||||
raise TypeError(
|
||||
f"Tag set {tags!r} invalid, must be an instance of dict, "
|
||||
"bson.son.SON or other type that inherits from "
|
||||
"collection.Mapping"
|
||||
)
|
||||
|
||||
return list(tag_sets)
|
||||
|
||||
|
||||
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
|
||||
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
|
||||
|
||||
|
||||
# Some duplication with common.py to avoid import cycle.
|
||||
def _validate_max_staleness(max_staleness: Any) -> int:
|
||||
"""Validate max_staleness."""
|
||||
if max_staleness == -1:
|
||||
return -1
|
||||
|
||||
if not isinstance(max_staleness, int):
|
||||
raise TypeError(_invalid_max_staleness_msg(max_staleness))
|
||||
|
||||
if max_staleness <= 0:
|
||||
raise ValueError(_invalid_max_staleness_msg(max_staleness))
|
||||
|
||||
return max_staleness
|
||||
|
||||
|
||||
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
|
||||
"""Validate hedge."""
|
||||
if hedge is None:
|
||||
return None
|
||||
|
||||
if not isinstance(hedge, dict):
|
||||
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
|
||||
|
||||
return hedge
|
||||
|
||||
|
||||
class _ServerMode:
|
||||
"""Base class for all read preferences."""
|
||||
|
||||
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: int,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
self.__mongos_mode = _MONGOS_MODES[mode]
|
||||
self.__mode = mode
|
||||
self.__tag_sets = _validate_tag_sets(tag_sets)
|
||||
self.__max_staleness = _validate_max_staleness(max_staleness)
|
||||
self.__hedge = _validate_hedge(hedge)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of this read preference."""
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def mongos_mode(self) -> str:
|
||||
"""The mongos mode of this read preference."""
|
||||
return self.__mongos_mode
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""Read preference as a document."""
|
||||
doc: dict[str, Any] = {"mode": self.__mongos_mode}
|
||||
if self.__tag_sets not in (None, [{}]):
|
||||
doc["tags"] = self.__tag_sets
|
||||
if self.__max_staleness != -1:
|
||||
doc["maxStalenessSeconds"] = self.__max_staleness
|
||||
if self.__hedge not in (None, {}):
|
||||
doc["hedge"] = self.__hedge
|
||||
return doc
|
||||
|
||||
@property
|
||||
def mode(self) -> int:
|
||||
"""The mode of this read preference instance."""
|
||||
return self.__mode
|
||||
|
||||
@property
|
||||
def tag_sets(self) -> _TagSets:
|
||||
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
|
||||
read only from members whose ``dc`` tag has the value ``"ny"``.
|
||||
To specify a priority-order for tag sets, provide a list of
|
||||
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
|
||||
set, ``{}``, means "read from any member that matches the mode,
|
||||
ignoring tags." MongoClient tries each set of tags in turn
|
||||
until it finds a set of tags with at least one matching member.
|
||||
For example, to only send a query to an analytic node::
|
||||
|
||||
Nearest(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
Or using :class:`SecondaryPreferred`::
|
||||
|
||||
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
.. seealso:: `Data-Center Awareness
|
||||
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
|
||||
"""
|
||||
return list(self.__tag_sets) if self.__tag_sets else [{}]
|
||||
|
||||
@property
|
||||
def max_staleness(self) -> int:
|
||||
"""The maximum estimated length of time (in seconds) a replica set
|
||||
secondary can fall behind the primary in replication before it will
|
||||
no longer be selected for operations, or -1 for no maximum.
|
||||
"""
|
||||
return self.__max_staleness
|
||||
|
||||
@property
|
||||
def hedge(self) -> Optional[_Hedge]:
|
||||
"""The read preference ``hedge`` parameter.
|
||||
|
||||
A dictionary that configures how the server will perform hedged reads.
|
||||
It consists of the following keys:
|
||||
|
||||
- ``enabled``: Enables or disables hedged reads in sharded clusters.
|
||||
|
||||
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
|
||||
``nearest`` read preference. To explicitly enable hedged reads, set
|
||||
the ``enabled`` key to ``true``::
|
||||
|
||||
>>> Nearest(hedge={'enabled': True})
|
||||
|
||||
To explicitly disable hedged reads, set the ``enabled`` key to
|
||||
``False``::
|
||||
|
||||
>>> Nearest(hedge={'enabled': False})
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
return self.__hedge
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
"""The wire protocol version the server must support.
|
||||
|
||||
Some read preferences impose version requirements on all servers (e.g.
|
||||
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
|
||||
|
||||
All servers' maxWireVersion must be at least this read preference's
|
||||
`min_wire_version`, or the driver raises
|
||||
:exc:`~pymongo.errors.ConfigurationError`.
|
||||
"""
|
||||
return 0 if self.__max_staleness == -1 else 5
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
|
||||
self.name,
|
||||
self.__tag_sets,
|
||||
self.__max_staleness,
|
||||
self.__hedge,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return (
|
||||
self.mode == other.mode
|
||||
and self.tag_sets == other.tag_sets
|
||||
and self.max_staleness == other.max_staleness
|
||||
and self.hedge == other.hedge
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
"""Return value of object for pickling.
|
||||
|
||||
Needed explicitly because __slots__() defined.
|
||||
"""
|
||||
return {
|
||||
"mode": self.__mode,
|
||||
"tag_sets": self.__tag_sets,
|
||||
"max_staleness": self.__max_staleness,
|
||||
"hedge": self.__hedge,
|
||||
}
|
||||
|
||||
def __setstate__(self, value: Mapping[str, Any]) -> None:
|
||||
"""Restore from pickling."""
|
||||
self.__mode = value["mode"]
|
||||
self.__mongos_mode = _MONGOS_MODES[self.__mode]
|
||||
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
|
||||
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
|
||||
self.__hedge = _validate_hedge(value["hedge"])
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
return selection
|
||||
|
||||
|
||||
class Primary(_ServerMode):
|
||||
"""Primary read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed if the server
|
||||
is standalone or a replica set primary.
|
||||
* When connected to a mongos queries are sent to the primary of a shard.
|
||||
* When connected to a replica set queries are sent to the primary of
|
||||
the replica set.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(_PRIMARY)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to a Selection."""
|
||||
return selection.primary_selection
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Primary()"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return other.mode == _PRIMARY
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class PrimaryPreferred(_ServerMode):
|
||||
"""PrimaryPreferred read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are sent to the primary of a shard if
|
||||
available, otherwise a shard secondary.
|
||||
* When connected to a replica set queries are sent to the primary if
|
||||
available, otherwise a secondary.
|
||||
|
||||
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
|
||||
created reads will be routed to an available secondary until the
|
||||
primary of the replica set is discovered.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` to use if the primary is not
|
||||
available.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` to use if the primary is not available.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
if selection.primary:
|
||||
return selection.primary_selection
|
||||
else:
|
||||
return secondary_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
|
||||
class Secondary(_ServerMode):
|
||||
"""Secondary read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are distributed among shard
|
||||
secondaries. An error is raised if no secondaries are available.
|
||||
* When connected to a replica set queries are distributed among
|
||||
secondaries. An error is raised if no secondaries are available.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` for this read preference.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` for this read preference.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
return secondary_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
|
||||
class SecondaryPreferred(_ServerMode):
|
||||
"""SecondaryPreferred read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are distributed among shard
|
||||
secondaries, or the shard primary if no secondary is available.
|
||||
* When connected to a replica set queries are distributed among
|
||||
secondaries, or the primary if no secondary is available.
|
||||
|
||||
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
|
||||
created reads will be routed to the primary of the replica set until
|
||||
an available secondary is discovered.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` for this read preference.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` for this read preference.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
secondaries = secondary_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
if secondaries:
|
||||
return secondaries
|
||||
else:
|
||||
return selection.primary_selection
|
||||
|
||||
|
||||
class Nearest(_ServerMode):
|
||||
"""Nearest read preference.
|
||||
|
||||
* When directly connected to one mongod queries are allowed to standalone
|
||||
servers, to a replica set primary, or to replica set secondaries.
|
||||
* When connected to a mongos queries are distributed among all members of
|
||||
a shard.
|
||||
* When connected to a replica set queries are distributed among all
|
||||
members.
|
||||
|
||||
:param tag_sets: The :attr:`~tag_sets` for this read preference.
|
||||
:param max_staleness: (integer, in seconds) The maximum estimated
|
||||
length of time a replica set secondary can fall behind the primary in
|
||||
replication before it will no longer be selected for operations.
|
||||
Default -1, meaning no maximum. If it is set, it must be at least
|
||||
90 seconds.
|
||||
:param hedge: The :attr:`~hedge` for this read preference.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Added ``hedge`` parameter.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to Selection."""
|
||||
return member_with_tags_server_selector(
|
||||
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
|
||||
)
|
||||
|
||||
|
||||
class _AggWritePref:
|
||||
"""Agg $out/$merge write preference.
|
||||
|
||||
* If there are readable servers and there is any pre-5.0 server, use
|
||||
primary read preference.
|
||||
* Otherwise use `pref` read preference.
|
||||
|
||||
:param pref: The read preference to use on MongoDB 5.0+.
|
||||
"""
|
||||
|
||||
__slots__ = ("pref", "effective_pref")
|
||||
|
||||
def __init__(self, pref: _ServerMode):
|
||||
self.pref = pref
|
||||
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
|
||||
|
||||
def selection_hook(self, topology_description: TopologyDescription) -> None:
|
||||
common_wv = topology_description.common_wire_version
|
||||
if (
|
||||
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
|
||||
and common_wv
|
||||
and common_wv < 13
|
||||
):
|
||||
self.effective_pref = ReadPreference.PRIMARY
|
||||
else:
|
||||
self.effective_pref = self.pref
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to a Selection."""
|
||||
return self.effective_pref(selection)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"_AggWritePref(pref={self.pref!r})"
|
||||
|
||||
# Proxy other calls to the effective_pref so that _AggWritePref can be
|
||||
# used in place of an actual read preference.
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.effective_pref, name)
|
||||
|
||||
|
||||
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
|
||||
|
||||
|
||||
def make_read_preference(
|
||||
mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
|
||||
) -> _ServerMode:
|
||||
if mode == _PRIMARY:
|
||||
if tag_sets not in (None, [{}]):
|
||||
raise ConfigurationError("Read preference primary cannot be combined with tags")
|
||||
if max_staleness != -1:
|
||||
raise ConfigurationError(
|
||||
"Read preference primary cannot be combined with maxStalenessSeconds"
|
||||
)
|
||||
return Primary()
|
||||
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
|
||||
|
||||
|
||||
_MODES = (
|
||||
"PRIMARY",
|
||||
"PRIMARY_PREFERRED",
|
||||
"SECONDARY",
|
||||
"SECONDARY_PREFERRED",
|
||||
"NEAREST",
|
||||
)
|
||||
|
||||
|
||||
class ReadPreference:
|
||||
"""An enum that defines some commonly used read preference modes.
|
||||
|
||||
Apps can also create a custom read preference, for example::
|
||||
|
||||
Nearest(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
See :doc:`/examples/high_availability` for code examples.
|
||||
|
||||
A read preference is used in three cases:
|
||||
|
||||
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
|
||||
|
||||
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
|
||||
set primary.
|
||||
- All other modes allow queries to standalone servers, to a replica set
|
||||
primary, or to replica set secondaries.
|
||||
|
||||
:class:`~pymongo.mongo_client.MongoClient` initialized with the
|
||||
``replicaSet`` option:
|
||||
|
||||
- ``PRIMARY``: Read from the primary. This is the default, and provides the
|
||||
strongest consistency. If no primary is available, raise
|
||||
:class:`~pymongo.errors.AutoReconnect`.
|
||||
|
||||
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
|
||||
none, read from a secondary.
|
||||
|
||||
- ``SECONDARY``: Read from a secondary. If no secondary is available,
|
||||
raise :class:`~pymongo.errors.AutoReconnect`.
|
||||
|
||||
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
|
||||
from the primary.
|
||||
|
||||
- ``NEAREST``: Read from any member.
|
||||
|
||||
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
|
||||
sharded cluster of replica sets:
|
||||
|
||||
- ``PRIMARY``: Read from the primary of the shard, or raise
|
||||
:class:`~pymongo.errors.OperationFailure` if there is none.
|
||||
This is the default.
|
||||
|
||||
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
|
||||
none, read from a secondary of the shard.
|
||||
|
||||
- ``SECONDARY``: Read from a secondary of the shard, or raise
|
||||
:class:`~pymongo.errors.OperationFailure` if there is none.
|
||||
|
||||
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
|
||||
otherwise from the shard primary.
|
||||
|
||||
- ``NEAREST``: Read from any shard member.
|
||||
"""
|
||||
|
||||
PRIMARY = Primary()
|
||||
PRIMARY_PREFERRED = PrimaryPreferred()
|
||||
SECONDARY = Secondary()
|
||||
SECONDARY_PREFERRED = SecondaryPreferred()
|
||||
NEAREST = Nearest()
|
||||
|
||||
|
||||
def read_pref_mode_from_name(name: str) -> int:
|
||||
"""Get the read preference mode from mongos/uri name."""
|
||||
return _MONGOS_MODES.index(name)
|
||||
|
||||
|
||||
class MovingAverage:
|
||||
"""Tracks an exponentially-weighted moving average."""
|
||||
|
||||
average: Optional[float]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.average = None
|
||||
|
||||
def add_sample(self, sample: float) -> None:
|
||||
if sample < 0:
|
||||
# Likely system time change while waiting for hello response
|
||||
# and not using time.monotonic. Ignore it, the next one will
|
||||
# probably be valid.
|
||||
return
|
||||
if self.average is None:
|
||||
self.average = sample
|
||||
else:
|
||||
# The Server Selection Spec requires an exponentially weighted
|
||||
# average with alpha = 0.2.
|
||||
self.average = 0.8 * self.average + 0.2 * sample
|
||||
|
||||
def get(self) -> Optional[float]:
|
||||
"""Get the calculated average, or None if no samples yet."""
|
||||
return self.average
|
||||
|
||||
def reset(self) -> None:
|
||||
self.average = None
|
||||
|
||||
@ -20,11 +20,8 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.synchronous.message import _OpMsg, _OpReply
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.typings import _Address, _DocumentOut
|
||||
|
||||
_IS_SYNC = True
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.typings import _Address, _AgnosticConnection, _DocumentOut
|
||||
|
||||
|
||||
class Response:
|
||||
@ -92,7 +89,7 @@ class PinnedResponse(Response):
|
||||
self,
|
||||
data: Union[_OpMsg, _OpReply],
|
||||
address: _Address,
|
||||
conn: Connection,
|
||||
conn: _AgnosticConnection,
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
@ -103,7 +100,7 @@ class PinnedResponse(Response):
|
||||
|
||||
:param data: A network response message.
|
||||
:param address: (host, port) of the source server.
|
||||
:param conn: The Connection used for the initial query.
|
||||
:param conn: The AsyncConnection/Connection used for the initial query.
|
||||
:param request_id: The request id of this operation.
|
||||
:param duration: The duration of the operation.
|
||||
:param from_command: If the response is the result of a db command.
|
||||
@ -116,8 +113,8 @@ class PinnedResponse(Response):
|
||||
self._more_to_come = more_to_come
|
||||
|
||||
@property
|
||||
def conn(self) -> Connection:
|
||||
"""The Connection used for the initial query.
|
||||
def conn(self) -> _AgnosticConnection:
|
||||
"""The AsyncConnection/Connection used for the initial query.
|
||||
|
||||
The server will send batches on this socket, without waiting for
|
||||
getMores from the client, until the result set is exhausted or there
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,10 +12,288 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Re-import of synchronous ServerDescription API for compatibility."""
|
||||
"""Represent one server the driver is connected to."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pymongo.synchronous.server_description import * # noqa: F403
|
||||
from pymongo.synchronous.server_description import __doc__ as original_doc
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
__doc__ = original_doc
|
||||
from bson import EPOCH_NAIVE
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
|
||||
class ServerDescription:
|
||||
"""Immutable representation of one server.
|
||||
|
||||
:param address: A (host, port) pair
|
||||
:param hello: Optional Hello instance
|
||||
:param round_trip_time: Optional float
|
||||
:param error: Optional, the last error attempting to connect to the server
|
||||
:param round_trip_time: Optional float, the min latency from the most recent samples
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_address",
|
||||
"_server_type",
|
||||
"_all_hosts",
|
||||
"_tags",
|
||||
"_replica_set_name",
|
||||
"_primary",
|
||||
"_max_bson_size",
|
||||
"_max_message_size",
|
||||
"_max_write_batch_size",
|
||||
"_min_wire_version",
|
||||
"_max_wire_version",
|
||||
"_round_trip_time",
|
||||
"_min_round_trip_time",
|
||||
"_me",
|
||||
"_is_writable",
|
||||
"_is_readable",
|
||||
"_ls_timeout_minutes",
|
||||
"_error",
|
||||
"_set_version",
|
||||
"_election_id",
|
||||
"_cluster_time",
|
||||
"_last_write_date",
|
||||
"_last_update_time",
|
||||
"_topology_version",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: _Address,
|
||||
hello: Optional[Hello] = None,
|
||||
round_trip_time: Optional[float] = None,
|
||||
error: Optional[Exception] = None,
|
||||
min_round_trip_time: float = 0.0,
|
||||
) -> None:
|
||||
self._address = address
|
||||
if not hello:
|
||||
hello = Hello({})
|
||||
|
||||
self._server_type = hello.server_type
|
||||
self._all_hosts = hello.all_hosts
|
||||
self._tags = hello.tags
|
||||
self._replica_set_name = hello.replica_set_name
|
||||
self._primary = hello.primary
|
||||
self._max_bson_size = hello.max_bson_size
|
||||
self._max_message_size = hello.max_message_size
|
||||
self._max_write_batch_size = hello.max_write_batch_size
|
||||
self._min_wire_version = hello.min_wire_version
|
||||
self._max_wire_version = hello.max_wire_version
|
||||
self._set_version = hello.set_version
|
||||
self._election_id = hello.election_id
|
||||
self._cluster_time = hello.cluster_time
|
||||
self._is_writable = hello.is_writable
|
||||
self._is_readable = hello.is_readable
|
||||
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
|
||||
self._round_trip_time = round_trip_time
|
||||
self._min_round_trip_time = min_round_trip_time
|
||||
self._me = hello.me
|
||||
self._last_update_time = time.monotonic()
|
||||
self._error = error
|
||||
self._topology_version = hello.topology_version
|
||||
if error:
|
||||
details = getattr(error, "details", None)
|
||||
if isinstance(details, dict):
|
||||
self._topology_version = details.get("topologyVersion")
|
||||
|
||||
self._last_write_date: Optional[float]
|
||||
if hello.last_write_date:
|
||||
# Convert from datetime to seconds.
|
||||
delta = hello.last_write_date - EPOCH_NAIVE
|
||||
self._last_write_date = delta.total_seconds()
|
||||
else:
|
||||
self._last_write_date = None
|
||||
|
||||
@property
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) of this server."""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def server_type(self) -> int:
|
||||
"""The type of this server."""
|
||||
return self._server_type
|
||||
|
||||
@property
|
||||
def server_type_name(self) -> str:
|
||||
"""The server type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return SERVER_TYPE._fields[self._server_type]
|
||||
|
||||
@property
|
||||
def all_hosts(self) -> set[tuple[str, int]]:
|
||||
"""List of hosts, passives, and arbiters known to this server."""
|
||||
return self._all_hosts
|
||||
|
||||
@property
|
||||
def tags(self) -> Mapping[str, Any]:
|
||||
return self._tags
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def primary(self) -> Optional[tuple[str, int]]:
|
||||
"""This server's opinion about who the primary is, or None."""
|
||||
return self._primary
|
||||
|
||||
@property
|
||||
def max_bson_size(self) -> int:
|
||||
return self._max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self) -> int:
|
||||
return self._max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self) -> int:
|
||||
return self._max_write_batch_size
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
return self._min_wire_version
|
||||
|
||||
@property
|
||||
def max_wire_version(self) -> int:
|
||||
return self._max_wire_version
|
||||
|
||||
@property
|
||||
def set_version(self) -> Optional[int]:
|
||||
return self._set_version
|
||||
|
||||
@property
|
||||
def election_id(self) -> Optional[ObjectId]:
|
||||
return self._election_id
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
|
||||
warnings.warn(
|
||||
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._set_version, self._election_id
|
||||
|
||||
@property
|
||||
def me(self) -> Optional[tuple[str, int]]:
|
||||
return self._me
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def last_write_date(self) -> Optional[float]:
|
||||
return self._last_write_date
|
||||
|
||||
@property
|
||||
def last_update_time(self) -> float:
|
||||
return self._last_update_time
|
||||
|
||||
@property
|
||||
def round_trip_time(self) -> Optional[float]:
|
||||
"""The current average latency or None."""
|
||||
# This override is for unittesting only!
|
||||
if self._address in self._host_to_round_trip_time:
|
||||
return self._host_to_round_trip_time[self._address]
|
||||
|
||||
return self._round_trip_time
|
||||
|
||||
@property
|
||||
def min_round_trip_time(self) -> float:
|
||||
"""The min latency from the most recent samples."""
|
||||
return self._min_round_trip_time
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[Exception]:
|
||||
"""The last error attempting to connect to the server, or None."""
|
||||
return self._error
|
||||
|
||||
@property
|
||||
def is_writable(self) -> bool:
|
||||
return self._is_writable
|
||||
|
||||
@property
|
||||
def is_readable(self) -> bool:
|
||||
return self._is_readable
|
||||
|
||||
@property
|
||||
def mongos(self) -> bool:
|
||||
return self._server_type == SERVER_TYPE.Mongos
|
||||
|
||||
@property
|
||||
def is_server_type_known(self) -> bool:
|
||||
return self.server_type != SERVER_TYPE.Unknown
|
||||
|
||||
@property
|
||||
def retryable_writes_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return (
|
||||
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
|
||||
) or self._server_type == SERVER_TYPE.LoadBalancer
|
||||
|
||||
@property
|
||||
def retryable_reads_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return self._max_wire_version >= 6
|
||||
|
||||
@property
|
||||
def topology_version(self) -> Optional[Mapping[str, Any]]:
|
||||
return self._topology_version
|
||||
|
||||
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
|
||||
unknown = ServerDescription(self.address, error=error)
|
||||
unknown._topology_version = self.topology_version
|
||||
return unknown
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, ServerDescription):
|
||||
return (
|
||||
(self._address == other.address)
|
||||
and (self._server_type == other.server_type)
|
||||
and (self._min_wire_version == other.min_wire_version)
|
||||
and (self._max_wire_version == other.max_wire_version)
|
||||
and (self._me == other.me)
|
||||
and (self._all_hosts == other.all_hosts)
|
||||
and (self._tags == other.tags)
|
||||
and (self._replica_set_name == other.replica_set_name)
|
||||
and (self._set_version == other.set_version)
|
||||
and (self._election_id == other.election_id)
|
||||
and (self._primary == other.primary)
|
||||
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
|
||||
and (self._error == other.error)
|
||||
)
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self) -> str:
|
||||
errmsg = ""
|
||||
if self.error:
|
||||
errmsg = f", error={self.error!r}"
|
||||
return "<{} {} server_type: {}, rtt: {}{}>".format(
|
||||
self.__class__.__name__,
|
||||
self.address,
|
||||
self.server_type_name,
|
||||
self.round_trip_time,
|
||||
errmsg,
|
||||
)
|
||||
|
||||
# For unittesting only. Use under no circumstances!
|
||||
_host_to_round_trip_time: dict = {}
|
||||
|
||||
@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cas
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.server_description import ServerDescription
|
||||
from pymongo.synchronous.topology_description import TopologyDescription
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
T = TypeVar("T")
|
||||
TagSet = Mapping[str, Any]
|
||||
@ -19,14 +19,12 @@ import ipaddress
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.synchronous.common import CONNECT_TIMEOUT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dns import resolver
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _have_dnspython() -> bool:
|
||||
try:
|
||||
@ -18,20 +18,20 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Mapping, MutableMapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo import common
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.collation import validate_collation_or_none
|
||||
from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref
|
||||
from pymongo.read_preferences import ReadPreference, _AggWritePref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.read_preferences import _ServerMode
|
||||
from pymongo.synchronous.server import Server
|
||||
from pymongo.synchronous.typings import _DocumentType, _Pipeline
|
||||
from pymongo.typings import _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@ -18,16 +18,12 @@ from __future__ import annotations
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
@ -36,20 +32,22 @@ from typing import (
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
from pymongo.auth_shared import (
|
||||
MongoCredential,
|
||||
_authenticate_scram_start,
|
||||
_parse_scram_response,
|
||||
_xor,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
from pymongo.synchronous.auth_aws import _authenticate_aws
|
||||
from pymongo.synchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
_OIDCAzureCallback,
|
||||
_OIDCGCPCallback,
|
||||
_OIDCProperties,
|
||||
_OIDCTestCallback,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.hello import Hello
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
@ -68,210 +66,6 @@ except ImportError:
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
MECHANISMS = frozenset(
|
||||
[
|
||||
"GSSAPI",
|
||||
"MONGODB-CR",
|
||||
"MONGODB-OIDC",
|
||||
"MONGODB-X509",
|
||||
"MONGODB-AWS",
|
||||
"PLAIN",
|
||||
"SCRAM-SHA-1",
|
||||
"SCRAM-SHA-256",
|
||||
"DEFAULT",
|
||||
]
|
||||
)
|
||||
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
class _Cache:
|
||||
__slots__ = ("data",)
|
||||
|
||||
_hash_val = hash("_Cache")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = None
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Two instances must always compare equal.
|
||||
if isinstance(other, _Cache):
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, _Cache):
|
||||
return False
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self._hash_val
|
||||
|
||||
|
||||
MongoCredential = namedtuple(
|
||||
"MongoCredential",
|
||||
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
|
||||
)
|
||||
"""A hashable namedtuple of values used for authentication."""
|
||||
|
||||
|
||||
GSSAPIProperties = namedtuple(
|
||||
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
|
||||
)
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(
|
||||
mech: str,
|
||||
source: Optional[str],
|
||||
user: str,
|
||||
passwd: str,
|
||||
extra: Mapping[str, Any],
|
||||
database: Optional[str],
|
||||
) -> MongoCredential:
|
||||
"""Build and return a mechanism specific credentials tuple."""
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
|
||||
raise ConfigurationError(f"{mech} requires a username.")
|
||||
if mech == "GSSAPI":
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for GSSAPI")
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
service_name = properties.get("SERVICE_NAME", "mongodb")
|
||||
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
|
||||
service_realm = properties.get("SERVICE_REALM")
|
||||
props = GSSAPIProperties(
|
||||
service_name=service_name,
|
||||
canonicalize_host_name=canonicalize,
|
||||
service_realm=service_realm,
|
||||
)
|
||||
# Source is always $external.
|
||||
return MongoCredential(mech, "$external", user, passwd, props, None)
|
||||
elif mech == "MONGODB-X509":
|
||||
if passwd is not None:
|
||||
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for MONGODB-X509")
|
||||
# Source is always $external, user can be None.
|
||||
return MongoCredential(mech, "$external", user, None, None, None)
|
||||
elif mech == "MONGODB-AWS":
|
||||
if user is not None and passwd is None:
|
||||
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
|
||||
if source is not None and source != "$external":
|
||||
raise ConfigurationError(
|
||||
"authentication source must be $external or None for MONGODB-AWS"
|
||||
)
|
||||
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
aws_session_token = properties.get("AWS_SESSION_TOKEN")
|
||||
aws_props = _AWSProperties(aws_session_token=aws_session_token)
|
||||
# user can be None for temporary link-local EC2 credentials.
|
||||
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
|
||||
elif mech == "MONGODB-OIDC":
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
callback = properties.get("OIDC_CALLBACK")
|
||||
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
|
||||
environ = properties.get("ENVIRONMENT")
|
||||
token_resource = properties.get("TOKEN_RESOURCE", "")
|
||||
default_allowed = [
|
||||
"*.mongodb.net",
|
||||
"*.mongodb-dev.net",
|
||||
"*.mongodb-qa.net",
|
||||
"*.mongodbgov.net",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
]
|
||||
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
|
||||
msg = (
|
||||
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
|
||||
)
|
||||
if passwd is not None:
|
||||
msg = "password is not supported by MONGODB-OIDC"
|
||||
raise ConfigurationError(msg)
|
||||
if callback or human_callback:
|
||||
if environ is not None:
|
||||
raise ConfigurationError(msg)
|
||||
if callback and human_callback:
|
||||
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
|
||||
raise ConfigurationError(msg)
|
||||
elif environ is not None:
|
||||
if environ == "test":
|
||||
if user is not None:
|
||||
msg = "test environment for MONGODB-OIDC does not support username"
|
||||
raise ConfigurationError(msg)
|
||||
callback = _OIDCTestCallback()
|
||||
elif environ == "azure":
|
||||
passwd = None
|
||||
if not token_resource:
|
||||
raise ConfigurationError(
|
||||
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
|
||||
)
|
||||
callback = _OIDCAzureCallback(token_resource)
|
||||
elif environ == "gcp":
|
||||
passwd = None
|
||||
if not token_resource:
|
||||
raise ConfigurationError(
|
||||
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
|
||||
)
|
||||
callback = _OIDCGCPCallback(token_resource)
|
||||
else:
|
||||
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
|
||||
else:
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
oidc_props = _OIDCProperties(
|
||||
callback=callback,
|
||||
human_callback=human_callback,
|
||||
environment=environ,
|
||||
allowed_hosts=allowed_hosts,
|
||||
token_resource=token_resource,
|
||||
username=user,
|
||||
)
|
||||
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
|
||||
|
||||
elif mech == "PLAIN":
|
||||
source_database = source or database or "$external"
|
||||
return MongoCredential(mech, source_database, user, passwd, None, None)
|
||||
else:
|
||||
source_database = source or database or "admin"
|
||||
if passwd is None:
|
||||
raise ConfigurationError("A password is required.")
|
||||
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
"""XOR two byte strings together."""
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
|
||||
for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_scram_start(
|
||||
credentials: MongoCredential, mechanism: str
|
||||
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
|
||||
username = credentials.username
|
||||
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
|
||||
nonce = standard_b64encode(os.urandom(32))
|
||||
first_bare = b"n=" + user + b",r=" + nonce
|
||||
|
||||
cmd = {
|
||||
"saslStart": 1,
|
||||
"mechanism": mechanism,
|
||||
"payload": Binary(b"n,," + first_bare),
|
||||
"autoAuthorize": 1,
|
||||
"options": {"skipEmptyExchange": True},
|
||||
}
|
||||
return nonce, first_bare, cmd
|
||||
|
||||
|
||||
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
|
||||
"""Authenticate using SCRAM."""
|
||||
|
||||
@ -23,7 +23,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.typings import _ReadableBuffer
|
||||
from pymongo.synchronous.auth import MongoCredential
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -15,79 +15,35 @@
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo._azure_helpers import _get_azure_response
|
||||
from pymongo._csot import remaining
|
||||
from pymongo._gcp_helpers import _get_gcp_response
|
||||
from pymongo.auth_oidc_shared import (
|
||||
CALLBACK_VERSION,
|
||||
HUMAN_CALLBACK_TIMEOUT_SECONDS,
|
||||
MACHINE_CALLBACK_TIMEOUT_SECONDS,
|
||||
TIME_BETWEEN_CALLS_SECONDS,
|
||||
OIDCCallback,
|
||||
OIDCCallbackContext,
|
||||
OIDCCallbackResult,
|
||||
OIDCIdPInfo,
|
||||
_OIDCProperties,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
|
||||
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.auth import MongoCredential
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCIdPInfo:
|
||||
issuer: str
|
||||
clientId: Optional[str] = field(default=None)
|
||||
requestScopes: Optional[list[str]] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCCallbackContext:
|
||||
timeout_seconds: float
|
||||
username: str
|
||||
version: int
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCCallbackResult:
|
||||
access_token: str
|
||||
expires_in_seconds: Optional[float] = field(default=None)
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
class OIDCCallback(abc.ABC):
|
||||
"""A base class for defining OIDC callbacks."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
"""Convert the given BSON value into our own type."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCProperties:
|
||||
callback: Optional[OIDCCallback] = field(default=None)
|
||||
human_callback: Optional[OIDCCallback] = field(default=None)
|
||||
environment: Optional[str] = field(default=None)
|
||||
allowed_hosts: list[str] = field(default_factory=list)
|
||||
token_resource: Optional[str] = field(default=None)
|
||||
username: str = ""
|
||||
|
||||
|
||||
"""Mechanism properties for MONGODB-OIDC authentication."""
|
||||
|
||||
TOKEN_BUFFER_MINUTES = 5
|
||||
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
|
||||
CALLBACK_VERSION = 1
|
||||
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
|
||||
TIME_BETWEEN_CALLS_SECONDS = 0.1
|
||||
|
||||
|
||||
def _get_authenticator(
|
||||
credentials: MongoCredential, address: tuple[str, int]
|
||||
) -> _OIDCAuthenticator:
|
||||
@ -117,48 +73,6 @@ def _get_authenticator(
|
||||
return credentials.cache.data
|
||||
|
||||
|
||||
class _OIDCTestCallback(OIDCCallback):
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
token_file = os.environ.get("OIDC_TOKEN_FILE")
|
||||
if not token_file:
|
||||
raise RuntimeError(
|
||||
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
|
||||
)
|
||||
with open(token_file) as fid:
|
||||
return OIDCCallbackResult(access_token=fid.read().strip())
|
||||
|
||||
|
||||
class _OIDCAWSCallback(OIDCCallback):
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
|
||||
if not token_file:
|
||||
raise RuntimeError(
|
||||
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
|
||||
)
|
||||
with open(token_file) as fid:
|
||||
return OIDCCallbackResult(access_token=fid.read().strip())
|
||||
|
||||
|
||||
class _OIDCAzureCallback(OIDCCallback):
|
||||
def __init__(self, token_resource: str) -> None:
|
||||
self.token_resource = quote(token_resource)
|
||||
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
|
||||
return OIDCCallbackResult(
|
||||
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
|
||||
)
|
||||
|
||||
|
||||
class _OIDCGCPCallback(OIDCCallback):
|
||||
def __init__(self, token_resource: str) -> None:
|
||||
self.token_resource = quote(token_resource)
|
||||
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
|
||||
return OIDCCallbackResult(access_token=resp["access_token"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCAuthenticator:
|
||||
username: str
|
||||
|
||||
@ -19,6 +19,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
from collections.abc import MutableMapping
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
@ -26,7 +28,6 @@ from typing import (
|
||||
Any,
|
||||
Iterator,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
@ -34,140 +35,50 @@ from typing import (
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import _csot
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
from pymongo import _csot, common
|
||||
from pymongo.bulk_shared import (
|
||||
_COMMANDS,
|
||||
_DELETE_ALL,
|
||||
_merge_command,
|
||||
_raise_bulk_write_error,
|
||||
_Run,
|
||||
)
|
||||
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
|
||||
from pymongo.synchronous.common import (
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo.synchronous.helpers import _get_wce_doc
|
||||
from pymongo.synchronous.message import (
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import (
|
||||
_DELETE,
|
||||
_INSERT,
|
||||
_UPDATE,
|
||||
_BulkWriteContext,
|
||||
_convert_exception,
|
||||
_convert_write_result,
|
||||
_EncryptedBulkWriteContext,
|
||||
_randint,
|
||||
)
|
||||
from pymongo.synchronous.read_preferences import ReadPreference
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
|
||||
from pymongo.synchronous.helpers import _handle_reauth
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
_DELETE_ONE: int = 1
|
||||
|
||||
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
|
||||
_BAD_VALUE: int = 2
|
||||
_UNKNOWN_ERROR: int = 8
|
||||
_WRITE_CONCERN_ERROR: int = 64
|
||||
|
||||
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
|
||||
|
||||
|
||||
class _Run:
|
||||
"""Represents a batch of write operations."""
|
||||
|
||||
def __init__(self, op_type: int) -> None:
|
||||
"""Initialize a new Run object."""
|
||||
self.op_type: int = op_type
|
||||
self.index_map: list[int] = []
|
||||
self.ops: list[Any] = []
|
||||
self.idx_offset: int = 0
|
||||
|
||||
def index(self, idx: int) -> int:
|
||||
"""Get the original index of an operation in this run.
|
||||
|
||||
:param idx: The Run index that maps to the original index.
|
||||
"""
|
||||
return self.index_map[idx]
|
||||
|
||||
def add(self, original_index: int, operation: Any) -> None:
|
||||
"""Add an operation to this Run instance.
|
||||
|
||||
:param original_index: The original index of this operation
|
||||
within a larger bulk operation.
|
||||
:param operation: The operation document.
|
||||
"""
|
||||
self.index_map.append(original_index)
|
||||
self.ops.append(operation)
|
||||
|
||||
|
||||
def _merge_command(
|
||||
run: _Run,
|
||||
full_result: MutableMapping[str, Any],
|
||||
offset: int,
|
||||
result: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""Merge a write command result into the full bulk result."""
|
||||
affected = result.get("n", 0)
|
||||
|
||||
if run.op_type == _INSERT:
|
||||
full_result["nInserted"] += affected
|
||||
|
||||
elif run.op_type == _DELETE:
|
||||
full_result["nRemoved"] += affected
|
||||
|
||||
elif run.op_type == _UPDATE:
|
||||
upserted = result.get("upserted")
|
||||
if upserted:
|
||||
n_upserted = len(upserted)
|
||||
for doc in upserted:
|
||||
doc["index"] = run.index(doc["index"] + offset)
|
||||
full_result["upserted"].extend(upserted)
|
||||
full_result["nUpserted"] += n_upserted
|
||||
full_result["nMatched"] += affected - n_upserted
|
||||
else:
|
||||
full_result["nMatched"] += affected
|
||||
full_result["nModified"] += result["nModified"]
|
||||
|
||||
write_errors = result.get("writeErrors")
|
||||
if write_errors:
|
||||
for doc in write_errors:
|
||||
# Leave the server response intact for APM.
|
||||
replacement = doc.copy()
|
||||
idx = doc["index"] + offset
|
||||
replacement["index"] = run.index(idx)
|
||||
# Add the failed operation to the error document.
|
||||
replacement["op"] = run.ops[idx]
|
||||
full_result["writeErrors"].append(replacement)
|
||||
|
||||
wce = _get_wce_doc(result)
|
||||
if wce:
|
||||
full_result["writeConcernErrors"].append(wce)
|
||||
|
||||
|
||||
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
|
||||
"""Raise a BulkWriteError from the full bulk api result."""
|
||||
# retryWrites on MMAPv1 should raise an actionable error.
|
||||
if full_result["writeErrors"]:
|
||||
full_result["writeErrors"].sort(key=lambda error: error["index"])
|
||||
err = full_result["writeErrors"][0]
|
||||
code = err["code"]
|
||||
msg = err["errmsg"]
|
||||
if code == 20 and msg.startswith("Transaction numbers"):
|
||||
errmsg = (
|
||||
"This MongoDB deployment does not support "
|
||||
"retryable writes. Please add retryWrites=false "
|
||||
"to your connection string."
|
||||
)
|
||||
raise OperationFailure(errmsg, code, full_result)
|
||||
raise BulkWriteError(full_result)
|
||||
|
||||
|
||||
class _Bulk:
|
||||
"""The private guts of the bulk write API."""
|
||||
@ -204,13 +115,16 @@ class _Bulk:
|
||||
# Extra state so that we know where to pick up on a retry attempt.
|
||||
self.current_run = None
|
||||
self.next_run = None
|
||||
self.is_encrypted = False
|
||||
|
||||
@property
|
||||
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
|
||||
encrypter = self.collection.database.client._encrypter
|
||||
if encrypter and not encrypter._bypass_auto_encryption:
|
||||
self.is_encrypted = True
|
||||
return _EncryptedBulkWriteContext
|
||||
else:
|
||||
self.is_encrypted = False
|
||||
return _BulkWriteContext
|
||||
|
||||
def add_insert(self, document: _DocumentOut) -> None:
|
||||
@ -315,6 +229,230 @@ class _Bulk:
|
||||
if run.ops:
|
||||
yield run
|
||||
|
||||
@_handle_reauth
|
||||
def write_command(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
cmd[bwc.field] = docs
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
|
||||
if bwc.publish:
|
||||
bwc._fail(request_id, failure, duration)
|
||||
raise
|
||||
finally:
|
||||
bwc.start_time = datetime.datetime.now()
|
||||
return reply # type: ignore[return-value]
|
||||
|
||||
def unack_write(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
max_doc_size: int,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for Connection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
cmd = bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if result is not None:
|
||||
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
|
||||
else:
|
||||
# Comply with APM spec.
|
||||
reply = {"ok": 1}
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration)
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, OperationFailure):
|
||||
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
|
||||
elif isinstance(exc, NotPrimaryError):
|
||||
failure = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
if bwc.publish:
|
||||
assert bwc.start_time is not None
|
||||
bwc._fail(request_id, failure, duration)
|
||||
raise
|
||||
finally:
|
||||
bwc.start_time = datetime.datetime.now()
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def _execute_batch_unack(
|
||||
self,
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
bwc.conn.command( # type: ignore[misc]
|
||||
bwc.db_name,
|
||||
batched_cmd, # type: ignore[arg-type]
|
||||
write_concern=WriteConcern(w=0),
|
||||
session=bwc.session, # type: ignore[arg-type]
|
||||
client=client, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
request_id, msg, to_send = bwc.batch_command(cmd, ops)
|
||||
# Though this isn't strictly a "legacy" write, the helper
|
||||
# handles publishing commands and sending our message
|
||||
# without receiving a result. Send 0 for max_doc_size
|
||||
# to disable size checking. Size checking is handled while
|
||||
# the documents are encoded to BSON.
|
||||
self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type]
|
||||
|
||||
return to_send
|
||||
|
||||
def _execute_batch(
|
||||
self,
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
result = bwc.conn.command( # type: ignore[misc]
|
||||
bwc.db_name,
|
||||
batched_cmd, # type: ignore[arg-type]
|
||||
codec_options=bwc.codec,
|
||||
session=bwc.session, # type: ignore[arg-type]
|
||||
client=client, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
request_id, msg, to_send = bwc.batch_command(cmd, ops)
|
||||
result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
|
||||
client._process_response(result, bwc.session) # type: ignore[arg-type]
|
||||
|
||||
return result, to_send # type: ignore[return-value]
|
||||
|
||||
def _execute_command(
|
||||
self,
|
||||
generator: Iterator[Any],
|
||||
@ -387,7 +525,7 @@ class _Bulk:
|
||||
|
||||
# Run as many ops as possible in one command.
|
||||
if write_concern.acknowledged:
|
||||
result, to_send = bwc.execute(cmd, ops, client)
|
||||
result, to_send = self._execute_batch(bwc, cmd, ops, client)
|
||||
|
||||
# Retryable writeConcernErrors halt the execution of this run.
|
||||
wce = result.get("writeConcernError", {})
|
||||
@ -407,7 +545,7 @@ class _Bulk:
|
||||
if self.ordered and "writeErrors" in result:
|
||||
break
|
||||
else:
|
||||
to_send = bwc.execute_unack(cmd, ops, client)
|
||||
to_send = self._execute_batch_unack(bwc, cmd, ops, client)
|
||||
|
||||
run.idx_offset += len(to_send)
|
||||
|
||||
@ -458,7 +596,7 @@ class _Bulk:
|
||||
retryable_bulk,
|
||||
session,
|
||||
operation,
|
||||
bulk=self,
|
||||
bulk=self, # type: ignore[arg-type]
|
||||
operation_id=op_id,
|
||||
)
|
||||
|
||||
@ -499,7 +637,7 @@ class _Bulk:
|
||||
conn.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible.
|
||||
to_send = bwc.execute_unack(cmd, ops, client)
|
||||
to_send = self._execute_batch_unack(bwc, cmd, ops, client)
|
||||
run.idx_offset += len(to_send)
|
||||
self.current_run = run = next(generator, None)
|
||||
|
||||
|
||||
@ -21,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo import _csot, common
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
CursorNotFound,
|
||||
@ -29,16 +30,14 @@ from pymongo.errors import (
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.synchronous.aggregation import (
|
||||
_AggregationCommand,
|
||||
_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand,
|
||||
)
|
||||
from pymongo.synchronous.collation import validate_collation_or_none
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.operations import _Op
|
||||
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@ -1,334 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Tools to parse mongo client options."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.compression_support import CompressionSettings
|
||||
from pymongo.synchronous.monitoring import _EventListener, _EventListeners
|
||||
from pymongo.synchronous.pool import PoolOptions
|
||||
from pymongo.synchronous.read_preferences import (
|
||||
_ServerMode,
|
||||
make_read_preference,
|
||||
read_pref_mode_from_name,
|
||||
)
|
||||
from pymongo.synchronous.server_selectors import any_server_selector
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.synchronous.auth import MongoCredential
|
||||
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.synchronous.topology_description import _ServerSelector
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _parse_credentials(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> Optional[MongoCredential]:
|
||||
"""Parse authentication credentials."""
|
||||
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
|
||||
source = options.get("authsource")
|
||||
if username or mechanism:
|
||||
from pymongo.synchronous.auth import _build_credentials_tuple
|
||||
|
||||
return _build_credentials_tuple(mechanism, source, username, password, options, database)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
|
||||
"""Parse read preference options."""
|
||||
if "read_preference" in options:
|
||||
return options["read_preference"]
|
||||
|
||||
name = options.get("readpreference", "primary")
|
||||
mode = read_pref_mode_from_name(name)
|
||||
tags = options.get("readpreferencetags")
|
||||
max_staleness = options.get("maxstalenessseconds", -1)
|
||||
return make_read_preference(mode, tags, max_staleness)
|
||||
|
||||
|
||||
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
|
||||
"""Parse write concern options."""
|
||||
concern = options.get("w")
|
||||
wtimeout = options.get("wtimeoutms")
|
||||
j = options.get("journal")
|
||||
fsync = options.get("fsync")
|
||||
return WriteConcern(concern, wtimeout, j, fsync)
|
||||
|
||||
|
||||
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
|
||||
"""Parse read concern options."""
|
||||
concern = options.get("readconcernlevel")
|
||||
return ReadConcern(concern)
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
|
||||
"""Parse ssl options."""
|
||||
use_tls = options.get("tls")
|
||||
if use_tls is not None:
|
||||
validate_boolean("tls", use_tls)
|
||||
|
||||
certfile = options.get("tlscertificatekeyfile")
|
||||
passphrase = options.get("tlscertificatekeyfilepassword")
|
||||
ca_certs = options.get("tlscafile")
|
||||
crlfile = options.get("tlscrlfile")
|
||||
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
|
||||
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
|
||||
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
|
||||
|
||||
enabled_tls_opts = []
|
||||
for opt in (
|
||||
"tlscertificatekeyfile",
|
||||
"tlscertificatekeyfilepassword",
|
||||
"tlscafile",
|
||||
"tlscrlfile",
|
||||
):
|
||||
# Any non-null value of these options implies tls=True.
|
||||
if opt in options and options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
for opt in (
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
):
|
||||
# A value of False for these options implies tls=True.
|
||||
if opt in options and not options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
|
||||
if enabled_tls_opts:
|
||||
if use_tls is None:
|
||||
# Implicitly enable TLS when one of the tls* options is set.
|
||||
use_tls = True
|
||||
elif not use_tls:
|
||||
# Error since tls is explicitly disabled but a tls option is set.
|
||||
raise ConfigurationError(
|
||||
"TLS has not been enabled but the "
|
||||
"following tls parameters have been set: "
|
||||
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
|
||||
)
|
||||
|
||||
if use_tls:
|
||||
ctx = get_ssl_context(
|
||||
certfile,
|
||||
passphrase,
|
||||
ca_certs,
|
||||
crlfile,
|
||||
allow_invalid_certificates,
|
||||
allow_invalid_hostnames,
|
||||
disable_ocsp_endpoint_check,
|
||||
)
|
||||
return ctx, allow_invalid_hostnames
|
||||
return None, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> PoolOptions:
|
||||
"""Parse connection pool options."""
|
||||
credentials = _parse_credentials(username, password, database, options)
|
||||
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
|
||||
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
|
||||
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
|
||||
if max_pool_size is not None and min_pool_size > max_pool_size:
|
||||
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
|
||||
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
|
||||
socket_timeout = options.get("sockettimeoutms")
|
||||
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
|
||||
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
|
||||
appname = options.get("appname")
|
||||
driver = options.get("driver")
|
||||
server_api = options.get("server_api")
|
||||
compression_settings = CompressionSettings(
|
||||
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
|
||||
)
|
||||
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
|
||||
load_balanced = options.get("loadbalanced")
|
||||
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
|
||||
return PoolOptions(
|
||||
max_pool_size,
|
||||
min_pool_size,
|
||||
max_idle_time_seconds,
|
||||
connect_timeout,
|
||||
socket_timeout,
|
||||
wait_queue_timeout,
|
||||
ssl_context,
|
||||
tls_allow_invalid_hostnames,
|
||||
_EventListeners(event_listeners),
|
||||
appname,
|
||||
driver,
|
||||
compression_settings,
|
||||
max_connecting=max_connecting,
|
||||
server_api=server_api,
|
||||
load_balanced=load_balanced,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
|
||||
class ClientOptions:
|
||||
"""Read only configuration options for a MongoClient.
|
||||
|
||||
Should not be instantiated directly by application developers. Access
|
||||
a client's options via :attr:`pymongo.mongo_client.MongoClient.options`
|
||||
instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
):
|
||||
self.__options = options
|
||||
self.__codec_options = _parse_codec_options(options)
|
||||
self.__direct_connection = options.get("directconnection")
|
||||
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
|
||||
# self.__server_selection_timeout is in seconds. Must use full name for
|
||||
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
|
||||
self.__server_selection_timeout = options.get(
|
||||
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
|
||||
)
|
||||
self.__pool_options = _parse_pool_options(username, password, database, options)
|
||||
self.__read_preference = _parse_read_preference(options)
|
||||
self.__replica_set_name = options.get("replicaset")
|
||||
self.__write_concern = _parse_write_concern(options)
|
||||
self.__read_concern = _parse_read_concern(options)
|
||||
self.__connect = options.get("connect")
|
||||
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
|
||||
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
|
||||
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
|
||||
self.__server_selector = options.get("server_selector", any_server_selector)
|
||||
self.__auto_encryption_opts = options.get("auto_encryption_opts")
|
||||
self.__load_balanced = options.get("loadbalanced")
|
||||
self.__timeout = options.get("timeoutms")
|
||||
self.__server_monitoring_mode = options.get(
|
||||
"servermonitoringmode", common.SERVER_MONITORING_MODE
|
||||
)
|
||||
|
||||
@property
|
||||
def _options(self) -> Mapping[str, Any]:
|
||||
"""The original options used to create this ClientOptions."""
|
||||
return self.__options
|
||||
|
||||
@property
|
||||
def connect(self) -> Optional[bool]:
|
||||
"""Whether to begin discovering a MongoDB topology automatically."""
|
||||
return self.__connect
|
||||
|
||||
@property
|
||||
def codec_options(self) -> CodecOptions:
|
||||
"""A :class:`~bson.codec_options.CodecOptions` instance."""
|
||||
return self.__codec_options
|
||||
|
||||
@property
|
||||
def direct_connection(self) -> Optional[bool]:
|
||||
"""Whether to connect to the deployment in 'Single' topology."""
|
||||
return self.__direct_connection
|
||||
|
||||
@property
|
||||
def local_threshold_ms(self) -> int:
|
||||
"""The local threshold for this instance."""
|
||||
return self.__local_threshold_ms
|
||||
|
||||
@property
|
||||
def server_selection_timeout(self) -> int:
|
||||
"""The server selection timeout for this instance in seconds."""
|
||||
return self.__server_selection_timeout
|
||||
|
||||
@property
|
||||
def server_selector(self) -> _ServerSelector:
|
||||
return self.__server_selector
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
"""The monitoring frequency in seconds."""
|
||||
return self.__heartbeat_frequency
|
||||
|
||||
@property
|
||||
def pool_options(self) -> PoolOptions:
|
||||
"""A :class:`~pymongo.pool.PoolOptions` instance."""
|
||||
return self.__pool_options
|
||||
|
||||
@property
|
||||
def read_preference(self) -> _ServerMode:
|
||||
"""A read preference instance."""
|
||||
return self.__read_preference
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self.__replica_set_name
|
||||
|
||||
@property
|
||||
def write_concern(self) -> WriteConcern:
|
||||
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
|
||||
return self.__write_concern
|
||||
|
||||
@property
|
||||
def read_concern(self) -> ReadConcern:
|
||||
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
|
||||
return self.__read_concern
|
||||
|
||||
@property
|
||||
def timeout(self) -> Optional[float]:
|
||||
"""The configured timeoutMS converted to seconds, or None.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
return self.__timeout
|
||||
|
||||
@property
|
||||
def retry_writes(self) -> bool:
|
||||
"""If this instance should retry supported write operations."""
|
||||
return self.__retry_writes
|
||||
|
||||
@property
|
||||
def retry_reads(self) -> bool:
|
||||
"""If this instance should retry supported read operations."""
|
||||
return self.__retry_reads
|
||||
|
||||
@property
|
||||
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
|
||||
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
|
||||
return self.__auto_encryption_opts
|
||||
|
||||
@property
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if the client was configured to connect to a load balancer."""
|
||||
return self.__load_balanced
|
||||
|
||||
@property
|
||||
def event_listeners(self) -> list[_EventListeners]:
|
||||
"""The event listeners registered for this client.
|
||||
|
||||
See :mod:`~pymongo.monitoring` for details.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
assert self.__pool_options._event_listeners is not None
|
||||
return self.__pool_options._event_listeners.event_listeners()
|
||||
|
||||
@property
|
||||
def server_monitoring_mode(self) -> str:
|
||||
"""The configured serverMonitoringMode option.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
"""
|
||||
return self.__server_monitoring_mode
|
||||
@ -164,12 +164,12 @@ from pymongo.errors import (
|
||||
PyMongoError,
|
||||
WTimeoutError,
|
||||
)
|
||||
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous.cursor import _ConnectionManager
|
||||
from pymongo.synchronous.operations import _Op
|
||||
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -177,7 +177,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.server import Server
|
||||
from pymongo.synchronous.typings import ClusterTime, _Address
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@ -1,226 +0,0 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for working with `collations`_.
|
||||
|
||||
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class CollationStrength:
|
||||
"""
|
||||
An enum that defines values for `strength` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
PRIMARY = 1
|
||||
"""Differentiate base (unadorned) characters."""
|
||||
|
||||
SECONDARY = 2
|
||||
"""Differentiate character accents."""
|
||||
|
||||
TERTIARY = 3
|
||||
"""Differentiate character case."""
|
||||
|
||||
QUATERNARY = 4
|
||||
"""Differentiate words with and without punctuation."""
|
||||
|
||||
IDENTICAL = 5
|
||||
"""Differentiate unicode code point (characters are exactly identical)."""
|
||||
|
||||
|
||||
class CollationAlternate:
|
||||
"""
|
||||
An enum that defines values for `alternate` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
NON_IGNORABLE = "non-ignorable"
|
||||
"""Spaces and punctuation are treated as base characters."""
|
||||
|
||||
SHIFTED = "shifted"
|
||||
"""Spaces and punctuation are *not* considered base characters.
|
||||
|
||||
Spaces and punctuation are distinguished regardless when the
|
||||
:class:`~pymongo.collation.Collation` strength is at least
|
||||
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class CollationMaxVariable:
|
||||
"""
|
||||
An enum that defines values for `max_variable` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
PUNCT = "punct"
|
||||
"""Both punctuation and spaces are ignored."""
|
||||
|
||||
SPACE = "space"
|
||||
"""Spaces alone are ignored."""
|
||||
|
||||
|
||||
class CollationCaseFirst:
|
||||
"""
|
||||
An enum that defines values for `case_first` on a
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
"""
|
||||
|
||||
UPPER = "upper"
|
||||
"""Sort uppercase characters first."""
|
||||
|
||||
LOWER = "lower"
|
||||
"""Sort lowercase characters first."""
|
||||
|
||||
OFF = "off"
|
||||
"""Default for locale or collation strength."""
|
||||
|
||||
|
||||
class Collation:
|
||||
"""Collation
|
||||
|
||||
:param locale: (string) The locale of the collation. This should be a string
|
||||
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
|
||||
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
|
||||
documentation for a list of supported locales.
|
||||
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
|
||||
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
|
||||
greater than 2). Defaults to ``False``.
|
||||
:param caseFirst: (optional) Specify that either uppercase or lowercase
|
||||
characters take precedence. Must be one of the following values:
|
||||
|
||||
* :data:`~CollationCaseFirst.UPPER`
|
||||
* :data:`~CollationCaseFirst.LOWER`
|
||||
* :data:`~CollationCaseFirst.OFF` (the default)
|
||||
|
||||
:param strength: Specify the comparison strength. This is also
|
||||
known as the ICU comparison level. This must be one of the following
|
||||
values:
|
||||
|
||||
* :data:`~CollationStrength.PRIMARY`
|
||||
* :data:`~CollationStrength.SECONDARY`
|
||||
* :data:`~CollationStrength.TERTIARY` (the default)
|
||||
* :data:`~CollationStrength.QUATERNARY`
|
||||
* :data:`~CollationStrength.IDENTICAL`
|
||||
|
||||
Each successive level builds upon the previous. For example, a
|
||||
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
|
||||
characters based both on the unadorned base character and its accents.
|
||||
|
||||
:param numericOrdering: If ``True``, order numbers numerically
|
||||
instead of in collation order (defaults to ``False``).
|
||||
:param alternate: Specify whether spaces and punctuation are
|
||||
considered base characters. This must be one of the following values:
|
||||
|
||||
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
|
||||
* :data:`~CollationAlternate.SHIFTED`
|
||||
|
||||
:param maxVariable: When `alternate` is
|
||||
:data:`~CollationAlternate.SHIFTED`, this option specifies what
|
||||
characters may be ignored. This must be one of the following values:
|
||||
|
||||
* :data:`~CollationMaxVariable.PUNCT` (the default)
|
||||
* :data:`~CollationMaxVariable.SPACE`
|
||||
|
||||
:param normalization: If ``True``, normalizes text into Unicode
|
||||
NFD. Defaults to ``False``.
|
||||
:param backwards: If ``True``, accents on characters are
|
||||
considered from the back of the word to the front, as it is done in some
|
||||
French dictionary ordering traditions. Defaults to ``False``.
|
||||
:param kwargs: Keyword arguments supplying any additional options
|
||||
to be sent with this Collation object.
|
||||
|
||||
.. versionadded: 3.4
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
locale: str,
|
||||
caseLevel: Optional[bool] = None,
|
||||
caseFirst: Optional[str] = None,
|
||||
strength: Optional[int] = None,
|
||||
numericOrdering: Optional[bool] = None,
|
||||
alternate: Optional[str] = None,
|
||||
maxVariable: Optional[str] = None,
|
||||
normalization: Optional[bool] = None,
|
||||
backwards: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
locale = common.validate_string("locale", locale)
|
||||
self.__document: dict[str, Any] = {"locale": locale}
|
||||
if caseLevel is not None:
|
||||
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
|
||||
if caseFirst is not None:
|
||||
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
|
||||
if strength is not None:
|
||||
self.__document["strength"] = common.validate_integer("strength", strength)
|
||||
if numericOrdering is not None:
|
||||
self.__document["numericOrdering"] = validate_boolean(
|
||||
"numericOrdering", numericOrdering
|
||||
)
|
||||
if alternate is not None:
|
||||
self.__document["alternate"] = common.validate_string("alternate", alternate)
|
||||
if maxVariable is not None:
|
||||
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
|
||||
if normalization is not None:
|
||||
self.__document["normalization"] = validate_boolean("normalization", normalization)
|
||||
if backwards is not None:
|
||||
self.__document["backwards"] = validate_boolean("backwards", backwards)
|
||||
self.__document.update(kwargs)
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""The document representation of this collation.
|
||||
|
||||
.. note::
|
||||
:class:`Collation` is immutable. Mutating the value of
|
||||
:attr:`document` does not mutate this :class:`Collation`.
|
||||
"""
|
||||
return self.__document.copy()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
document = self.document
|
||||
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Collation):
|
||||
return self.document == other.document
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(
|
||||
value: Optional[Union[Mapping[str, Any], Collation]]
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, Collation):
|
||||
return value.document
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
|
||||
@ -41,41 +41,18 @@ from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import ASCENDING, _csot
|
||||
from pymongo import ASCENDING, _csot, common, helpers_shared, message
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
InvalidName,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.read_concern import DEFAULT_READ_CONCERN
|
||||
from pymongo.results import (
|
||||
BulkWriteResult,
|
||||
DeleteResult,
|
||||
InsertManyResult,
|
||||
InsertOneResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from pymongo.synchronous import common, helpers, message
|
||||
from pymongo.synchronous.aggregation import (
|
||||
_CollectionAggregationCommand,
|
||||
_CollectionRawAggregationCommand,
|
||||
)
|
||||
from pymongo.synchronous.bulk import _Bulk
|
||||
from pymongo.synchronous.change_stream import CollectionChangeStream
|
||||
from pymongo.synchronous.collation import validate_collation_or_none
|
||||
from pymongo.synchronous.command_cursor import (
|
||||
CommandCursor,
|
||||
RawBatchCommandCursor,
|
||||
)
|
||||
from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.synchronous.cursor import (
|
||||
Cursor,
|
||||
RawBatchCursor,
|
||||
)
|
||||
from pymongo.synchronous.helpers import _check_write_command_response
|
||||
from pymongo.synchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS
|
||||
from pymongo.synchronous.operations import (
|
||||
from pymongo.helpers_shared import _check_write_command_response
|
||||
from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
IndexModel,
|
||||
@ -88,8 +65,30 @@ from pymongo.synchronous.operations import (
|
||||
_IndexList,
|
||||
_Op,
|
||||
)
|
||||
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.read_concern import DEFAULT_READ_CONCERN
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import (
|
||||
BulkWriteResult,
|
||||
DeleteResult,
|
||||
InsertManyResult,
|
||||
InsertOneResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from pymongo.synchronous.aggregation import (
|
||||
_CollectionAggregationCommand,
|
||||
_CollectionRawAggregationCommand,
|
||||
)
|
||||
from pymongo.synchronous.bulk import _Bulk
|
||||
from pymongo.synchronous.change_stream import CollectionChangeStream
|
||||
from pymongo.synchronous.command_cursor import (
|
||||
CommandCursor,
|
||||
RawBatchCommandCursor,
|
||||
)
|
||||
from pymongo.synchronous.cursor import (
|
||||
Cursor,
|
||||
RawBatchCursor,
|
||||
)
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
|
||||
|
||||
_IS_SYNC = True
|
||||
@ -125,10 +124,10 @@ class ReturnDocument:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import bson
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.synchronous.aggregation import _AggregationCommand
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collation import Collation
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.server import Server
|
||||
@ -995,7 +994,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
hint = helpers_shared._index_document(hint)
|
||||
update_doc["hint"] = hint
|
||||
command = {"update": self.name, "ordered": ordered, "updates": [update_doc]}
|
||||
if let is not None:
|
||||
@ -1476,7 +1475,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
hint = helpers_shared._index_document(hint)
|
||||
delete_doc["hint"] = hint
|
||||
command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]}
|
||||
|
||||
@ -2092,7 +2091,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}})
|
||||
cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}}
|
||||
if "hint" in kwargs and not isinstance(kwargs["hint"], str):
|
||||
kwargs["hint"] = helpers._index_document(kwargs["hint"])
|
||||
kwargs["hint"] = helpers_shared._index_document(kwargs["hint"])
|
||||
collation = validate_collation_or_none(kwargs.pop("collation", None))
|
||||
cmd.update(kwargs)
|
||||
|
||||
@ -2425,7 +2424,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
) -> None:
|
||||
name = index_or_name
|
||||
if isinstance(index_or_name, list):
|
||||
name = helpers._gen_index_name(index_or_name)
|
||||
name = helpers_shared._gen_index_name(index_or_name)
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("index_or_name must be an instance of str or list")
|
||||
@ -3154,15 +3153,15 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd["let"] = let
|
||||
cmd.update(kwargs)
|
||||
if projection is not None:
|
||||
cmd["fields"] = helpers._fields_list_to_dict(projection, "projection")
|
||||
cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection")
|
||||
if sort is not None:
|
||||
cmd["sort"] = helpers._index_document(sort)
|
||||
cmd["sort"] = helpers_shared._index_document(sort)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
cmd["upsert"] = upsert
|
||||
if hint is not None:
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
hint = helpers_shared._index_document(hint)
|
||||
|
||||
write_concern = self._write_concern_for_cmd(cmd, session)
|
||||
|
||||
|
||||
@ -31,16 +31,16 @@ from typing import (
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.synchronous.cursor import _ConnectionManager
|
||||
from pymongo.synchronous.message import (
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
_OpReply,
|
||||
_RawBatchGetMore,
|
||||
)
|
||||
from pymongo.synchronous.response import PinnedResponse
|
||||
from pymongo.synchronous.typings import _Address, _DocumentOut, _DocumentType
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.synchronous.cursor import _ConnectionManager
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
@ -272,7 +272,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
|
||||
if response.from_command:
|
||||
cursor = response.docs[0]["cursor"]
|
||||
documents = cursor["nextBatch"]
|
||||
|
||||
@ -36,17 +36,16 @@ from typing import (
|
||||
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
|
||||
from bson.code import Code
|
||||
from bson.son import SON
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.synchronous import helpers
|
||||
from pymongo.synchronous.collation import validate_collation_or_none
|
||||
from pymongo.synchronous.common import (
|
||||
from pymongo import helpers_shared
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_is_mapping,
|
||||
)
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.synchronous.message import (
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
@ -55,18 +54,19 @@ from pymongo.synchronous.message import (
|
||||
_RawBatchGetMore,
|
||||
_RawBatchQuery,
|
||||
)
|
||||
from pymongo.synchronous.response import PinnedResponse
|
||||
from pymongo.synchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsItems
|
||||
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.read_preferences import _ServerMode
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -179,7 +179,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use)
|
||||
|
||||
if projection is not None:
|
||||
projection = helpers._fields_list_to_dict(projection, "projection")
|
||||
projection = helpers_shared._fields_list_to_dict(projection, "projection")
|
||||
|
||||
if let is not None:
|
||||
validate_is_document_type("let", let)
|
||||
@ -191,7 +191,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
self._skip = skip
|
||||
self._limit = limit
|
||||
self._batch_size = batch_size
|
||||
self._ordering = sort and helpers._index_document(sort) or None
|
||||
self._ordering = sort and helpers_shared._index_document(sort) or None
|
||||
self._max_scan = max_scan
|
||||
self._explain = False
|
||||
self._comment = comment
|
||||
@ -740,8 +740,8 @@ class Cursor(Generic[_DocumentType]):
|
||||
key, if not given :data:`~pymongo.ASCENDING` is assumed
|
||||
"""
|
||||
self._check_okay_to_chain()
|
||||
keys = helpers._index_list(key_or_list, direction)
|
||||
self._ordering = helpers._index_document(keys)
|
||||
keys = helpers_shared._index_list(key_or_list, direction)
|
||||
self._ordering = helpers_shared._index_document(keys)
|
||||
return self
|
||||
|
||||
def explain(self) -> _DocumentType:
|
||||
@ -772,7 +772,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
if isinstance(index, str):
|
||||
self._hint = index
|
||||
else:
|
||||
self._hint = helpers._index_document(index)
|
||||
self._hint = helpers_shared._index_document(index)
|
||||
|
||||
def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]:
|
||||
"""Adds a 'hint', telling Mongo the proper index to use for the query.
|
||||
@ -1120,7 +1120,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
self._address = response.address
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
|
||||
|
||||
cmd_name = operation.name
|
||||
docs = response.docs
|
||||
|
||||
@ -33,18 +33,17 @@ from typing import (
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
|
||||
from bson.dbref import DBRef
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo import _csot, common
|
||||
from pymongo.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.database_shared import _check_name, _CodecDocumentType
|
||||
from pymongo.errors import CollectionInvalid, InvalidOperation
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.synchronous.aggregation import _DatabaseAggregationCommand
|
||||
from pymongo.synchronous.change_stream import DatabaseChangeStream
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.synchronous.operations import _Op
|
||||
from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import bson
|
||||
@ -151,7 +150,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
>>> db1.read_preference
|
||||
Primary()
|
||||
>>> from pymongo.synchronous.read_preferences import Secondary
|
||||
>>> from pymongo.read_preferences import Secondary
|
||||
>>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}]))
|
||||
>>> db1.read_preference
|
||||
Primary()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user