PYTHON-4264 Async PyMongo Beta (#1629)

This commit is contained in:
Noah Stapp 2024-06-06 09:01:24 -07:00 committed by GitHub
parent e9c86f4c00
commit d6bf0e1e78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
211 changed files with 62315 additions and 23123 deletions

View File

@ -17,6 +17,17 @@ repos:
exclude: .patch
exclude_types: [json]
- repo: local
hooks:
- id: synchro
name: synchro
entry: bash ./tools/synchro.sh
language: python
require_serial: true
additional_dependencies:
- ruff==0.1.3
- unasync
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.3
@ -74,7 +85,7 @@ repos:
stages: [manual]
- repo: https://github.com/ariebovenberg/slotscheck
rev: v0.17.0
rev: v0.19.0
hooks:
- id: slotscheck
files: \.py$

View File

@ -22,6 +22,7 @@ include doc/make.bat
include doc/static/periodic-executor-refs.dot
recursive-include requirements *.txt
recursive-include tools *.py
recursive-include tools *.sh
include tools/README.rst
include green_framework_test.py
recursive-include test *.pem

View File

@ -3,13 +3,13 @@ Changelog
Changes in Version 4.8.0
-------------------------
The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time.
The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8).
.. warning:: PyMongo 4.8 drops support for Python 3.7 and PyPy 3.8: Python 3.8+ or PyPy 3.9+ is now required.
PyMongo 4.8 brings a number of improvements including:
- The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time.
- The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8).
- A new asynchronous API with full asyncio support.
Changes in Version 4.7.3
-------------------------

View File

@ -21,980 +21,34 @@ The :mod:`gridfs` package is an implementation of GridFS on top of
"""
from __future__ import annotations
from collections import abc
from typing import Any, Mapping, Optional, cast
from bson.objectid import ObjectId
from gridfs.asynchronous.grid_file import (
AsyncGridFS,
AsyncGridFSBucket,
AsyncGridIn,
AsyncGridOut,
AsyncGridOutCursor,
)
from gridfs.errors import NoFile
from gridfs.grid_file import (
DEFAULT_CHUNK_SIZE,
from gridfs.grid_file_shared import DEFAULT_CHUNK_SIZE
from gridfs.synchronous.grid_file import (
GridFS,
GridFSBucket,
GridIn,
GridOut,
GridOutCursor,
_clear_entity_type_registry,
_disallow_transactions,
)
from pymongo import ASCENDING, DESCENDING, _csot
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.common import validate_string
from pymongo.database import Database
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import _ServerMode
from pymongo.write_concern import WriteConcern
__all__ = [
"AsyncGridFS",
"GridFS",
"AsyncGridFSBucket",
"GridFSBucket",
"NoFile",
"DEFAULT_CHUNK_SIZE",
"AsyncGridIn",
"GridIn",
"AsyncGridOut",
"GridOut",
"AsyncGridOutCursor",
"GridOutCursor",
]
class GridFS:
"""An instance of GridFS on top of a single Database."""
def __init__(self, database: Database, collection: str = "fs"):
"""Create a new instance of :class:`GridFS`.
Raises :class:`TypeError` if `database` is not an instance of
:class:`~pymongo.database.Database`.
:param database: database to use
:param collection: root collection to use
.. versionchanged:: 4.0
Removed the `disable_md5` parameter. See
:ref:`removed-gridfs-checksum` for details.
.. versionchanged:: 3.11
Running a GridFS operation in a transaction now always raises an
error. GridFS does not support multi-document transactions.
.. versionchanged:: 3.7
Added the `disable_md5` parameter.
.. versionchanged:: 3.1
Indexes are only ensured on the first write to the DB.
.. versionchanged:: 3.0
`database` must use an acknowledged
:attr:`~pymongo.database.Database.write_concern`
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(database, Database):
raise TypeError("database must be an instance of Database")
database = _clear_entity_type_registry(database)
if not database.write_concern.acknowledged:
raise ConfigurationError("database must use acknowledged write_concern")
self.__collection = database[collection]
self.__files = self.__collection.files
self.__chunks = self.__collection.chunks
def new_file(self, **kwargs: Any) -> GridIn:
"""Create a new file in GridFS.
Returns a new :class:`~gridfs.grid_file.GridIn` instance to
which data can be written. Any keyword arguments will be
passed through to :meth:`~gridfs.grid_file.GridIn`.
If the ``"_id"`` of the file is manually specified, it must
not already exist in GridFS. Otherwise
:class:`~gridfs.errors.FileExists` is raised.
:param kwargs: keyword arguments for file creation
"""
return GridIn(self.__collection, **kwargs)
def put(self, data: Any, **kwargs: Any) -> Any:
"""Put data in GridFS as a new file.
Equivalent to doing::
with fs.new_file(**kwargs) as f:
f.write(data)
`data` can be either an instance of :class:`bytes` or a file-like
object providing a :meth:`read` method. If an `encoding` keyword
argument is passed, `data` can also be a :class:`str` instance, which
will be encoded as `encoding` before being written. Any keyword
arguments will be passed through to the created file - see
:meth:`~gridfs.grid_file.GridIn` for possible arguments. Returns the
``"_id"`` of the created file.
If the ``"_id"`` of the file is manually specified, it must
not already exist in GridFS. Otherwise
:class:`~gridfs.errors.FileExists` is raised.
:param data: data to be written as a file.
:param kwargs: keyword arguments for file creation
.. versionchanged:: 3.0
w=0 writes to GridFS are now prohibited.
"""
with GridIn(self.__collection, **kwargs) as grid_file:
grid_file.write(data)
return grid_file._id
def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut:
"""Get a file from GridFS by ``"_id"``.
Returns an instance of :class:`~gridfs.grid_file.GridOut`,
which provides a file-like interface for reading.
:param file_id: ``"_id"`` of the file to get
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
gout = GridOut(self.__collection, file_id, session=session)
# Raise NoFile now, instead of on first attribute access.
gout._ensure_file()
return gout
def get_version(
self,
filename: Optional[str] = None,
version: Optional[int] = -1,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> GridOut:
"""Get a file from GridFS by ``"filename"`` or metadata fields.
Returns a version of the file in GridFS whose filename matches
`filename` and whose metadata fields match the supplied keyword
arguments, as an instance of :class:`~gridfs.grid_file.GridOut`.
Version numbering is a convenience atop the GridFS API provided
by MongoDB. If more than one file matches the query (either by
`filename` alone, by metadata fields, or by a combination of
both), then version ``-1`` will be the most recently uploaded
matching file, ``-2`` the second most recently
uploaded, etc. Version ``0`` will be the first version
uploaded, ``1`` the second version, etc. So if three versions
have been uploaded, then version ``0`` is the same as version
``-3``, version ``1`` is the same as version ``-2``, and
version ``2`` is the same as version ``-1``.
Raises :class:`~gridfs.errors.NoFile` if no such version of
that file exists.
:param filename: ``"filename"`` of the file to get, or `None`
:param version: version of the file to get (defaults
to -1, the most recent version uploaded)
:param session: a
:class:`~pymongo.client_session.ClientSession`
:param kwargs: find files by custom metadata.
.. versionchanged:: 3.6
Added ``session`` parameter.
.. versionchanged:: 3.1
``get_version`` no longer ensures indexes.
"""
query = kwargs
if filename is not None:
query["filename"] = filename
_disallow_transactions(session)
cursor = self.__files.find(query, session=session)
if version is None:
version = -1
if version < 0:
skip = abs(version) - 1
cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING)
else:
cursor.limit(-1).skip(version).sort("uploadDate", ASCENDING)
try:
doc = next(cursor)
return GridOut(self.__collection, file_document=doc, session=session)
except StopIteration:
raise NoFile("no version %d for filename %r" % (version, filename)) from None
def get_last_version(
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
) -> GridOut:
"""Get the most recent version of a file in GridFS by ``"filename"``
or metadata fields.
Equivalent to calling :meth:`get_version` with the default
`version` (``-1``).
:param filename: ``"filename"`` of the file to get, or `None`
:param session: a
:class:`~pymongo.client_session.ClientSession`
:param kwargs: find files by custom metadata.
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
return self.get_version(filename=filename, session=session, **kwargs)
# TODO add optional safe mode for chunk removal?
def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
"""Delete a file from GridFS by ``"_id"``.
Deletes all data belonging to the file with ``"_id"``:
`file_id`.
.. warning:: Any processes/threads reading from the file while
this method is executing will likely see an invalid/corrupt
file. Care should be taken to avoid concurrent reads to a file
while it is being deleted.
.. note:: Deletes of non-existent files are considered successful
since the end result is the same: no file with that _id remains.
:param file_id: ``"_id"`` of the file to delete
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
.. versionchanged:: 3.1
``delete`` no longer ensures indexes.
"""
_disallow_transactions(session)
self.__files.delete_one({"_id": file_id}, session=session)
self.__chunks.delete_many({"files_id": file_id}, session=session)
def list(self, session: Optional[ClientSession] = None) -> list[str]:
"""List the names of all files stored in this instance of
:class:`GridFS`.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
.. versionchanged:: 3.1
``list`` no longer ensures indexes.
"""
_disallow_transactions(session)
# With an index, distinct includes documents with no filename
# as None.
return [
name for name in self.__files.distinct("filename", session=session) if name is not None
]
def find_one(
self,
filter: Optional[Any] = None,
session: Optional[ClientSession] = None,
*args: Any,
**kwargs: Any,
) -> Optional[GridOut]:
"""Get a single file from gridfs.
All arguments to :meth:`find` are also valid arguments for
:meth:`find_one`, although any `limit` argument will be
ignored. Returns a single :class:`~gridfs.grid_file.GridOut`,
or ``None`` if no matching file is found. For example:
.. code-block: python
file = fs.find_one({"filename": "lisa.txt"})
:param filter: a dictionary specifying
the query to be performing OR any other type to be used as
the value for a query for ``"_id"`` in the file collection.
:param args: any additional positional arguments are
the same as the arguments to :meth:`find`.
:param session: a
:class:`~pymongo.client_session.ClientSession`
:param kwargs: any additional keyword arguments
are the same as the arguments to :meth:`find`.
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
if filter is not None and not isinstance(filter, abc.Mapping):
filter = {"_id": filter}
_disallow_transactions(session)
for f in self.find(filter, *args, session=session, **kwargs):
return f
return None
def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
"""Query GridFS for files.
Returns a cursor that iterates across files matching
arbitrary queries on the files collection. Can be combined
with other modifiers for additional control. For example::
for grid_out in fs.find({"filename": "lisa.txt"},
no_cursor_timeout=True):
data = grid_out.read()
would iterate through all versions of "lisa.txt" stored in GridFS.
Note that setting no_cursor_timeout to True may be important to
prevent the cursor from timing out during long multi-file processing
work.
As another example, the call::
most_recent_three = fs.find().sort("uploadDate", -1).limit(3)
would return a cursor to the three most recently uploaded files
in GridFS.
Follows a similar interface to
:meth:`~pymongo.collection.Collection.find`
in :class:`~pymongo.collection.Collection`.
If a :class:`~pymongo.client_session.ClientSession` is passed to
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
are associated with that session.
:param filter: A query document that selects which files
to include in the result set. Can be an empty document to include
all files.
:param skip: the number of files to omit (from
the start of the result set) when returning the results
:param limit: the maximum number of results to
return
:param no_cursor_timeout: if False (the default), any
returned cursor is closed by the server after 10 minutes of
inactivity. If set to True, the returned cursor will never
time out on the server. Care should be taken to ensure that
cursors with no_cursor_timeout turned on are properly closed.
:param sort: a list of (key, direction) pairs
specifying the sort order for this query. See
:meth:`~pymongo.cursor.Cursor.sort` for details.
Raises :class:`TypeError` if any of the arguments are of
improper type. Returns an instance of
:class:`~gridfs.grid_file.GridOutCursor`
corresponding to this query.
.. versionchanged:: 3.0
Removed the read_preference, tag_sets, and
secondary_acceptable_latency_ms options.
.. versionadded:: 2.7
.. seealso:: The MongoDB documentation on `find <https://dochub.mongodb.org/core/find>`_.
"""
return GridOutCursor(self.__collection, *args, **kwargs)
def exists(
self,
document_or_id: Optional[Any] = None,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> bool:
"""Check if a file exists in this instance of :class:`GridFS`.
The file to check for can be specified by the value of its
``_id`` key, or by passing in a query document. A query
document can be passed in as dictionary, or by using keyword
arguments. Thus, the following three calls are equivalent:
>>> fs.exists(file_id)
>>> fs.exists({"_id": file_id})
>>> fs.exists(_id=file_id)
As are the following two calls:
>>> fs.exists({"filename": "mike.txt"})
>>> fs.exists(filename="mike.txt")
And the following two:
>>> fs.exists({"foo": {"$gt": 12}})
>>> fs.exists(foo={"$gt": 12})
Returns ``True`` if a matching file exists, ``False``
otherwise. Calls to :meth:`exists` will not automatically
create appropriate indexes; application developers should be
sure to create indexes if needed and as appropriate.
:param document_or_id: query document, or _id of the
document to check for
:param session: a
:class:`~pymongo.client_session.ClientSession`
:param kwargs: keyword arguments are used as a
query document, if they're present.
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
_disallow_transactions(session)
if kwargs:
f = self.__files.find_one(kwargs, ["_id"], session=session)
else:
f = self.__files.find_one(document_or_id, ["_id"], session=session)
return f is not None
class GridFSBucket:
"""An instance of GridFS on top of a single Database."""
def __init__(
self,
db: Database,
bucket_name: str = "fs",
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
) -> None:
"""Create a new instance of :class:`GridFSBucket`.
Raises :exc:`TypeError` if `database` is not an instance of
:class:`~pymongo.database.Database`.
Raises :exc:`~pymongo.errors.ConfigurationError` if `write_concern`
is not acknowledged.
:param database: database to use.
:param bucket_name: The name of the bucket. Defaults to 'fs'.
:param chunk_size_bytes: The chunk size in bytes. Defaults
to 255KB.
:param write_concern: The
:class:`~pymongo.write_concern.WriteConcern` to use. If ``None``
(the default) db.write_concern is used.
:param read_preference: The read preference to use. If
``None`` (the default) db.read_preference is used.
.. versionchanged:: 4.0
Removed the `disable_md5` parameter. See
:ref:`removed-gridfs-checksum` for details.
.. versionchanged:: 3.11
Running a GridFSBucket operation in a transaction now always raises
an error. GridFSBucket does not support multi-document transactions.
.. versionchanged:: 3.7
Added the `disable_md5` parameter.
.. versionadded:: 3.1
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(db, Database):
raise TypeError("database must be an instance of Database")
db = _clear_entity_type_registry(db)
wtc = write_concern if write_concern is not None else db.write_concern
if not wtc.acknowledged:
raise ConfigurationError("write concern must be acknowledged")
self._bucket_name = bucket_name
self._collection = db[bucket_name]
self._chunks: Collection = self._collection.chunks.with_options(
write_concern=write_concern, read_preference=read_preference
)
self._files: Collection = self._collection.files.with_options(
write_concern=write_concern, read_preference=read_preference
)
self._chunk_size_bytes = chunk_size_bytes
self._timeout = db.client.options.timeout
def open_upload_stream(
self,
filename: str,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
) -> GridIn:
"""Opens a Stream that the application can write the contents of the
file to.
The user must specify the filename, and can choose to add any
additional information in the metadata field of the file document or
modify the chunk size.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
with fs.open_upload_stream(
"test_file", chunk_size_bytes=4,
metadata={"contentType": "text/plain"}) as grid_in:
grid_in.write("data I want to store!")
# uploaded on close
Returns an instance of :class:`~gridfs.grid_file.GridIn`.
Raises :exc:`~gridfs.errors.NoFile` if no such version of
that file exists.
Raises :exc:`~ValueError` if `filename` is not a string.
:param filename: The name of the file to upload.
:param chunk_size_bytes` (options): The number of bytes per chunk of this
file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`.
:param metadata: User data for the 'metadata' field of the
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
validate_string("filename", filename)
opts = {
"filename": filename,
"chunk_size": (
chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes
),
}
if metadata is not None:
opts["metadata"] = metadata
return GridIn(self._collection, session=session, **opts)
def open_upload_stream_with_id(
self,
file_id: Any,
filename: str,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
) -> GridIn:
"""Opens a Stream that the application can write the contents of the
file to.
The user must specify the file id and filename, and can choose to add
any additional information in the metadata field of the file document
or modify the chunk size.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
with fs.open_upload_stream_with_id(
ObjectId(),
"test_file",
chunk_size_bytes=4,
metadata={"contentType": "text/plain"}) as grid_in:
grid_in.write("data I want to store!")
# uploaded on close
Returns an instance of :class:`~gridfs.grid_file.GridIn`.
Raises :exc:`~gridfs.errors.NoFile` if no such version of
that file exists.
Raises :exc:`~ValueError` if `filename` is not a string.
:param file_id: The id to use for this file. The id must not have
already been used for another file.
:param filename: The name of the file to upload.
:param chunk_size_bytes` (options): The number of bytes per chunk of this
file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`.
:param metadata: User data for the 'metadata' field of the
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
validate_string("filename", filename)
opts = {
"_id": file_id,
"filename": filename,
"chunk_size": (
chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes
),
}
if metadata is not None:
opts["metadata"] = metadata
return GridIn(self._collection, session=session, **opts)
@_csot.apply
def upload_from_stream(
self,
filename: str,
source: Any,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
) -> ObjectId:
"""Uploads a user file to a GridFS bucket.
Reads the contents of the user file from `source` and uploads
it to the file `filename`. Source can be a string or file-like object.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
file_id = fs.upload_from_stream(
"test_file",
"data I want to store!",
chunk_size_bytes=4,
metadata={"contentType": "text/plain"})
Returns the _id of the uploaded file.
Raises :exc:`~gridfs.errors.NoFile` if no such version of
that file exists.
Raises :exc:`~ValueError` if `filename` is not a string.
:param filename: The name of the file to upload.
:param source: The source stream of the content to be uploaded. Must be
a file-like object that implements :meth:`read` or a string.
:param chunk_size_bytes` (options): The number of bytes per chunk of this
file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`.
:param metadata: User data for the 'metadata' field of the
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
with self.open_upload_stream(filename, chunk_size_bytes, metadata, session=session) as gin:
gin.write(source)
return cast(ObjectId, gin._id)
@_csot.apply
def upload_from_stream_with_id(
self,
file_id: Any,
filename: str,
source: Any,
chunk_size_bytes: Optional[int] = None,
metadata: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
) -> None:
"""Uploads a user file to a GridFS bucket with a custom file id.
Reads the contents of the user file from `source` and uploads
it to the file `filename`. Source can be a string or file-like object.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
file_id = fs.upload_from_stream(
ObjectId(),
"test_file",
"data I want to store!",
chunk_size_bytes=4,
metadata={"contentType": "text/plain"})
Raises :exc:`~gridfs.errors.NoFile` if no such version of
that file exists.
Raises :exc:`~ValueError` if `filename` is not a string.
:param file_id: The id to use for this file. The id must not have
already been used for another file.
:param filename: The name of the file to upload.
:param source: The source stream of the content to be uploaded. Must be
a file-like object that implements :meth:`read` or a string.
:param chunk_size_bytes` (options): The number of bytes per chunk of this
file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`.
:param metadata: User data for the 'metadata' field of the
files collection document. If not provided the metadata field will
be omitted from the files collection document.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
with self.open_upload_stream_with_id(
file_id, filename, chunk_size_bytes, metadata, session=session
) as gin:
gin.write(source)
def open_download_stream(
self, file_id: Any, session: Optional[ClientSession] = None
) -> GridOut:
"""Opens a Stream from which the application can read the contents of
the stored file specified by file_id.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
# get _id of file to read.
file_id = fs.upload_from_stream("test_file", "data I want to store!")
grid_out = fs.open_download_stream(file_id)
contents = grid_out.read()
Returns an instance of :class:`~gridfs.grid_file.GridOut`.
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
:param file_id: The _id of the file to be downloaded.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
gout = GridOut(self._collection, file_id, session=session)
# Raise NoFile now, instead of on first attribute access.
gout._ensure_file()
return gout
@_csot.apply
def download_to_stream(
self, file_id: Any, destination: Any, session: Optional[ClientSession] = None
) -> None:
"""Downloads the contents of the stored file specified by file_id and
writes the contents to `destination`.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
# Get _id of file to read
file_id = fs.upload_from_stream("test_file", "data I want to store!")
# Get file to write to
file = open('myfile','wb+')
fs.download_to_stream(file_id, file)
file.seek(0)
contents = file.read()
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
:param file_id: The _id of the file to be downloaded.
:param destination: a file-like object implementing :meth:`write`.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
with self.open_download_stream(file_id, session=session) as gout:
while True:
chunk = gout.readchunk()
if not len(chunk):
break
destination.write(chunk)
@_csot.apply
def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
"""Given an file_id, delete this stored file's files collection document
and associated chunks from a GridFS bucket.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
# Get _id of file to delete
file_id = fs.upload_from_stream("test_file", "data I want to store!")
fs.delete(file_id)
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
:param file_id: The _id of the file to be deleted.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
_disallow_transactions(session)
res = self._files.delete_one({"_id": file_id}, session=session)
self._chunks.delete_many({"files_id": file_id}, session=session)
if not res.deleted_count:
raise NoFile("no file could be deleted because none matched %s" % file_id)
def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
"""Find and return the files collection documents that match ``filter``
Returns a cursor that iterates across files matching
arbitrary queries on the files collection. Can be combined
with other modifiers for additional control.
For example::
for grid_data in fs.find({"filename": "lisa.txt"},
no_cursor_timeout=True):
data = grid_data.read()
would iterate through all versions of "lisa.txt" stored in GridFS.
Note that setting no_cursor_timeout to True may be important to
prevent the cursor from timing out during long multi-file processing
work.
As another example, the call::
most_recent_three = fs.find().sort("uploadDate", -1).limit(3)
would return a cursor to the three most recently uploaded files
in GridFS.
Follows a similar interface to
:meth:`~pymongo.collection.Collection.find`
in :class:`~pymongo.collection.Collection`.
If a :class:`~pymongo.client_session.ClientSession` is passed to
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
are associated with that session.
:param filter: Search query.
:param batch_size: The number of documents to return per
batch.
:param limit: The maximum number of documents to return.
:param no_cursor_timeout: The server normally times out idle
cursors after an inactivity period (10 minutes) to prevent excess
memory use. Set this option to True prevent that.
:param skip: The number of documents to skip before
returning.
:param sort: The order by which to sort results. Defaults to
None.
"""
return GridOutCursor(self._collection, *args, **kwargs)
def open_download_stream_by_name(
self, filename: str, revision: int = -1, session: Optional[ClientSession] = None
) -> GridOut:
"""Opens a Stream from which the application can read the contents of
`filename` and optional `revision`.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
grid_out = fs.open_download_stream_by_name("test_file")
contents = grid_out.read()
Returns an instance of :class:`~gridfs.grid_file.GridOut`.
Raises :exc:`~gridfs.errors.NoFile` if no such version of
that file exists.
Raises :exc:`~ValueError` filename is not a string.
:param filename: The name of the file to read from.
:param revision: Which revision (documents with the same
filename and different uploadDate) of the file to retrieve.
Defaults to -1 (the most recent revision).
:param session: a
:class:`~pymongo.client_session.ClientSession`
:Note: Revision numbers are defined as follows:
- 0 = the original stored file
- 1 = the first revision
- 2 = the second revision
- etc...
- -2 = the second most recent revision
- -1 = the most recent revision
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
validate_string("filename", filename)
query = {"filename": filename}
_disallow_transactions(session)
cursor = self._files.find(query, session=session)
if revision < 0:
skip = abs(revision) - 1
cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING)
else:
cursor.limit(-1).skip(revision).sort("uploadDate", ASCENDING)
try:
grid_file = next(cursor)
return GridOut(self._collection, file_document=grid_file, session=session)
except StopIteration:
raise NoFile("no version %d for filename %r" % (revision, filename)) from None
@_csot.apply
def download_to_stream_by_name(
self,
filename: str,
destination: Any,
revision: int = -1,
session: Optional[ClientSession] = None,
) -> None:
"""Write the contents of `filename` (with optional `revision`) to
`destination`.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
# Get file to write to
file = open('myfile','wb')
fs.download_to_stream_by_name("test_file", file)
Raises :exc:`~gridfs.errors.NoFile` if no such version of
that file exists.
Raises :exc:`~ValueError` if `filename` is not a string.
:param filename: The name of the file to read from.
:param destination: A file-like object that implements :meth:`write`.
:param revision: Which revision (documents with the same
filename and different uploadDate) of the file to retrieve.
Defaults to -1 (the most recent revision).
:param session: a
:class:`~pymongo.client_session.ClientSession`
:Note: Revision numbers are defined as follows:
- 0 = the original stored file
- 1 = the first revision
- 2 = the second revision
- etc...
- -2 = the second most recent revision
- -1 = the most recent revision
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
with self.open_download_stream_by_name(filename, revision, session=session) as gout:
while True:
chunk = gout.readchunk()
if not len(chunk):
break
destination.write(chunk)
def rename(
self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None
) -> None:
"""Renames the stored file with the specified file_id.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
# Get _id of file to rename
file_id = fs.upload_from_stream("test_file", "data I want to store!")
fs.rename(file_id, "new_test_name")
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
:param file_id: The _id of the file to be renamed.
:param new_filename: The new name of the file.
:param session: a
:class:`~pymongo.client_session.ClientSession`
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
_disallow_transactions(session)
result = self._files.update_one(
{"_id": file_id}, {"$set": {"filename": new_filename}}, session=session
)
if not result.matched_count:
raise NoFile(
"no files could be renamed %r because none "
"matched file_id %i" % (new_filename, file_id)
)

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# Copyright 2009-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,953 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for representing files stored in GridFS."""
"""Re-import of synchronous gridfs API for compatibility."""
from __future__ import annotations
import datetime
import io
import math
import os
import warnings
from typing import Any, Iterable, Mapping, NoReturn, Optional
from bson.int64 import Int64
from bson.objectid import ObjectId
from gridfs.errors import CorruptGridFile, FileExists, NoFile
from pymongo import ASCENDING
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.cursor import Cursor
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
CursorNotFound,
DuplicateKeyError,
InvalidOperation,
OperationFailure,
)
from pymongo.helpers import _check_write_command_response
from pymongo.read_preferences import ReadPreference
_SEEK_SET = os.SEEK_SET
_SEEK_CUR = os.SEEK_CUR
_SEEK_END = os.SEEK_END
EMPTY = b""
NEWLN = b"\n"
"""Default chunk size, in bytes."""
# Slightly under a power of 2, to work well with server's record allocations.
DEFAULT_CHUNK_SIZE = 255 * 1024
# The number of chunked bytes to buffer before calling insert_many.
_UPLOAD_BUFFER_SIZE = MAX_MESSAGE_SIZE
# The number of chunk documents to buffer before calling insert_many.
_UPLOAD_BUFFER_CHUNKS = 100000
# Rough BSON overhead of a chunk document not including the chunk data itself.
# Essentially len(encode({"_id": ObjectId(), "files_id": ObjectId(), "n": 1, "data": ""}))
_CHUNK_OVERHEAD = 60
_C_INDEX: dict[str, Any] = {"files_id": ASCENDING, "n": ASCENDING}
_F_INDEX: dict[str, Any] = {"filename": ASCENDING, "uploadDate": ASCENDING}
def _grid_in_property(
field_name: str,
docstring: str,
read_only: Optional[bool] = False,
closed_only: Optional[bool] = False,
) -> Any:
"""Create a GridIn property."""
warn_str = ""
if docstring.startswith("DEPRECATED,"):
warn_str = (
f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
)
def getter(self: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if closed_only and not self._closed:
raise AttributeError("can only get %r on a closed file" % field_name)
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
def setter(self: Any, value: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}})
self._file[field_name] = value
if read_only:
docstring += "\n\nThis attribute is read-only."
elif closed_only:
docstring = "{}\n\n{}".format(
docstring,
"This attribute is read-only and "
"can only be read after :meth:`close` "
"has been called.",
)
if not read_only and not closed_only:
return property(getter, setter, doc=docstring)
return property(getter, doc=docstring)
def _grid_out_property(field_name: str, docstring: str) -> Any:
"""Create a GridOut property."""
warn_str = ""
if docstring.startswith("DEPRECATED,"):
warn_str = (
f"GridOut property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
)
def getter(self: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
self._ensure_file()
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
docstring += "\n\nThis attribute is read-only."
return property(getter, doc=docstring)
def _clear_entity_type_registry(entity: Any, **kwargs: Any) -> Any:
"""Clear the given database/collection object's type registry."""
codecopts = entity.codec_options.with_options(type_registry=None)
return entity.with_options(codec_options=codecopts, **kwargs)
def _disallow_transactions(session: Optional[ClientSession]) -> None:
if session and session.in_transaction:
raise InvalidOperation("GridFS does not support multi-document transactions")
class GridIn:
"""Class to write data to GridFS."""
def __init__(
self, root_collection: Collection, session: Optional[ClientSession] = None, **kwargs: Any
) -> None:
"""Write a file to GridFS
Application developers should generally not need to
instantiate this class directly - instead see the methods
provided by :class:`~gridfs.GridFS`.
Raises :class:`TypeError` if `root_collection` is not an
instance of :class:`~pymongo.collection.Collection`.
Any of the file level options specified in the `GridFS Spec
<http://dochub.mongodb.org/core/gridfsspec>`_ may be passed as
keyword arguments. Any additional keyword arguments will be
set as additional fields on the file document. Valid keyword
arguments include:
- ``"_id"``: unique ID for this file (default:
:class:`~bson.objectid.ObjectId`) - this ``"_id"`` must
not have already been used for another file
- ``"filename"``: human name for the file
- ``"contentType"`` or ``"content_type"``: valid mime-type
for the file
- ``"chunkSize"`` or ``"chunk_size"``: size of each of the
chunks, in bytes (default: 255 kb)
- ``"encoding"``: encoding used for this file. Any :class:`str`
that is written to the file will be converted to :class:`bytes`.
:param root_collection: root collection to write to
:param session: a
:class:`~pymongo.client_session.ClientSession` to use for all
commands
:param kwargs: Any: file level options (see above)
.. versionchanged:: 4.0
Removed the `disable_md5` parameter. See
:ref:`removed-gridfs-checksum` for details.
.. versionchanged:: 3.7
Added the `disable_md5` parameter.
.. versionchanged:: 3.6
Added ``session`` parameter.
.. versionchanged:: 3.0
`root_collection` must use an acknowledged
:attr:`~pymongo.collection.Collection.write_concern`
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an instance of Collection")
if not root_collection.write_concern.acknowledged:
raise ConfigurationError("root_collection must use acknowledged write_concern")
_disallow_transactions(session)
# Handle alternative naming
if "content_type" in kwargs:
kwargs["contentType"] = kwargs.pop("content_type")
if "chunk_size" in kwargs:
kwargs["chunkSize"] = kwargs.pop("chunk_size")
coll = _clear_entity_type_registry(root_collection, read_preference=ReadPreference.PRIMARY)
# Defaults
kwargs["_id"] = kwargs.get("_id", ObjectId())
kwargs["chunkSize"] = kwargs.get("chunkSize", DEFAULT_CHUNK_SIZE)
object.__setattr__(self, "_session", session)
object.__setattr__(self, "_coll", coll)
object.__setattr__(self, "_chunks", coll.chunks)
object.__setattr__(self, "_file", kwargs)
object.__setattr__(self, "_buffer", io.BytesIO())
object.__setattr__(self, "_position", 0)
object.__setattr__(self, "_chunk_number", 0)
object.__setattr__(self, "_closed", False)
object.__setattr__(self, "_ensured_index", False)
object.__setattr__(self, "_buffered_docs", [])
object.__setattr__(self, "_buffered_docs_size", 0)
def __create_index(self, collection: Collection, index_key: Any, unique: bool) -> None:
doc = collection.find_one(projection={"_id": 1}, session=self._session)
if doc is None:
try:
index_keys = [
index_spec["key"]
for index_spec in collection.list_indexes(session=self._session)
]
except OperationFailure:
index_keys = []
if index_key not in index_keys:
collection.create_index(index_key.items(), unique=unique, session=self._session)
def __ensure_indexes(self) -> None:
if not object.__getattribute__(self, "_ensured_index"):
_disallow_transactions(self._session)
self.__create_index(self._coll.files, _F_INDEX, False)
self.__create_index(self._coll.chunks, _C_INDEX, True)
object.__setattr__(self, "_ensured_index", True)
def abort(self) -> None:
"""Remove all chunks/files that may have been uploaded and close."""
self._coll.chunks.delete_many({"files_id": self._file["_id"]}, session=self._session)
self._coll.files.delete_one({"_id": self._file["_id"]}, session=self._session)
object.__setattr__(self, "_closed", True)
@property
def closed(self) -> bool:
"""Is this file closed?"""
return self._closed
_id: Any = _grid_in_property("_id", "The ``'_id'`` value for this file.", read_only=True)
filename: Optional[str] = _grid_in_property("filename", "Name of this file.")
name: Optional[str] = _grid_in_property("filename", "Alias for `filename`.")
content_type: Optional[str] = _grid_in_property(
"contentType", "DEPRECATED, will be removed in PyMongo 5.0. Mime-type for this file."
)
length: int = _grid_in_property("length", "Length (in bytes) of this file.", closed_only=True)
chunk_size: int = _grid_in_property("chunkSize", "Chunk size for this file.", read_only=True)
upload_date: datetime.datetime = _grid_in_property(
"uploadDate", "Date that this file was uploaded.", closed_only=True
)
md5: Optional[str] = _grid_in_property(
"md5",
"DEPRECATED, will be removed in PyMongo 5.0. MD5 of the contents of this file if an md5 sum was created.",
closed_only=True,
)
_buffer: io.BytesIO
_closed: bool
_buffered_docs: list[dict[str, Any]]
_buffered_docs_size: int
def __getattr__(self, name: str) -> Any:
if name in self._file:
return self._file[name]
raise AttributeError("GridIn object has no attribute '%s'" % name)
def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
def __flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
self.__ensure_indexes()
assert len(data) <= self.chunk_size
if data:
self._buffered_docs.append(
{"files_id": self._file["_id"], "n": self._chunk_number, "data": data}
)
self._buffered_docs_size += len(data) + _CHUNK_OVERHEAD
if not self._buffered_docs:
return
# Limit to 100,000 chunks or 32MB (+1 chunk) of data.
if (
force
or self._buffered_docs_size >= _UPLOAD_BUFFER_SIZE
or len(self._buffered_docs) >= _UPLOAD_BUFFER_CHUNKS
):
try:
self._chunks.insert_many(self._buffered_docs, session=self._session)
except BulkWriteError as exc:
# For backwards compatibility, raise an insert_one style exception.
write_errors = exc.details["writeErrors"]
for err in write_errors:
if err.get("code") in (11000, 11001, 12582): # Duplicate key errors
self._raise_file_exists(self._file["_id"])
result = {"writeErrors": write_errors}
wces = exc.details["writeConcernErrors"]
if wces:
result["writeConcernError"] = wces[-1]
_check_write_command_response(result)
raise
self._buffered_docs = []
self._buffered_docs_size = 0
self._chunk_number += 1
self._position += len(data)
def __flush_buffer(self, force: bool = False) -> None:
"""Flush the buffer contents out to a chunk."""
self.__flush_data(self._buffer.getvalue(), force=force)
self._buffer.close()
self._buffer = io.BytesIO()
def __flush(self) -> Any:
"""Flush the file to the database."""
try:
self.__flush_buffer(force=True)
# The GridFS spec says length SHOULD be an Int64.
self._file["length"] = Int64(self._position)
self._file["uploadDate"] = datetime.datetime.now(tz=datetime.timezone.utc)
return self._coll.files.insert_one(self._file, session=self._session)
except DuplicateKeyError:
self._raise_file_exists(self._id)
def _raise_file_exists(self, file_id: Any) -> NoReturn:
"""Raise a FileExists exception for the given file_id."""
raise FileExists("file with _id %r already exists" % file_id)
def close(self) -> None:
"""Flush the file and close it.
A closed file cannot be written any more. Calling
:meth:`close` more than once is allowed.
"""
if not self._closed:
self.__flush()
object.__setattr__(self, "_closed", True)
def read(self, size: int = -1) -> NoReturn:
raise io.UnsupportedOperation("read")
def readable(self) -> bool:
return False
def seekable(self) -> bool:
return False
def write(self, data: Any) -> None:
"""Write data to the file. There is no return value.
`data` can be either a string of bytes or a file-like object
(implementing :meth:`read`). If the file has an
:attr:`encoding` attribute, `data` can also be a
:class:`str` instance, which will be encoded as
:attr:`encoding` before being written.
Due to buffering, the data may not actually be written to the
database until the :meth:`close` method is called. Raises
:class:`ValueError` if this file is already closed. Raises
:class:`TypeError` if `data` is not an instance of
:class:`bytes`, a file-like object, or an instance of :class:`str`.
Unicode data is only allowed if the file has an :attr:`encoding`
attribute.
:param data: string of bytes or file-like object to be written
to the file
"""
if self._closed:
raise ValueError("cannot write to a closed file")
try:
# file-like
read = data.read
except AttributeError:
# string
if not isinstance(data, (str, bytes)):
raise TypeError("can only write strings or file-like objects") from None
if isinstance(data, str):
try:
data = data.encode(self.encoding)
except AttributeError:
raise TypeError(
"must specify an encoding for file in order to write str"
) from None
read = io.BytesIO(data).read
if self._buffer.tell() > 0:
# Make sure to flush only when _buffer is complete
space = self.chunk_size - self._buffer.tell()
if space:
try:
to_write = read(space)
except BaseException:
self.abort()
raise
self._buffer.write(to_write)
if len(to_write) < space:
return # EOF or incomplete
self.__flush_buffer()
to_write = read(self.chunk_size)
while to_write and len(to_write) == self.chunk_size:
self.__flush_data(to_write)
to_write = read(self.chunk_size)
self._buffer.write(to_write)
def writelines(self, sequence: Iterable[Any]) -> None:
"""Write a sequence of strings to the file.
Does not add separators.
"""
for line in sequence:
self.write(line)
def writeable(self) -> bool:
return True
def __enter__(self) -> GridIn:
"""Support for the context manager protocol."""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
"""Support for the context manager protocol.
Close the file if no exceptions occur and allow exceptions to propagate.
"""
if exc_type is None:
# No exceptions happened.
self.close()
else:
# Something happened, at minimum mark as closed.
object.__setattr__(self, "_closed", True)
# propagate exceptions
return False
class GridOut(io.IOBase):
"""Class to read data out of GridFS."""
def __init__(
self,
root_collection: Collection,
file_id: Optional[int] = None,
file_document: Optional[Any] = None,
session: Optional[ClientSession] = None,
) -> None:
"""Read a file from GridFS
Application developers should generally not need to
instantiate this class directly - instead see the methods
provided by :class:`~gridfs.GridFS`.
Either `file_id` or `file_document` must be specified,
`file_document` will be given priority if present. Raises
:class:`TypeError` if `root_collection` is not an instance of
:class:`~pymongo.collection.Collection`.
:param root_collection: root collection to read from
:param file_id: value of ``"_id"`` for the file to read
:param file_document: file document from
`root_collection.files`
:param session: a
:class:`~pymongo.client_session.ClientSession` to use for all
commands
.. versionchanged:: 3.8
For better performance and to better follow the GridFS spec,
:class:`GridOut` now uses a single cursor to read all the chunks in
the file.
.. versionchanged:: 3.6
Added ``session`` parameter.
.. versionchanged:: 3.0
Creating a GridOut does not immediately retrieve the file metadata
from the server. Metadata is fetched when first needed.
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an instance of Collection")
_disallow_transactions(session)
root_collection = _clear_entity_type_registry(root_collection)
super().__init__()
self.__chunks = root_collection.chunks
self.__files = root_collection.files
self.__file_id = file_id
self.__buffer = EMPTY
# Start position within the current buffered chunk.
self.__buffer_pos = 0
self.__chunk_iter = None
# Position within the total file.
self.__position = 0
self._file = file_document
self._session = session
_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _grid_out_property("filename", "Name of this file.")
name: str = _grid_out_property("filename", "Alias for `filename`.")
content_type: Optional[str] = _grid_out_property(
"contentType", "DEPRECATED, will be removed in PyMongo 5.0. Mime-type for this file."
)
length: int = _grid_out_property("length", "Length (in bytes) of this file.")
chunk_size: int = _grid_out_property("chunkSize", "Chunk size for this file.")
upload_date: datetime.datetime = _grid_out_property(
"uploadDate", "Date that this file was first uploaded."
)
aliases: Optional[list[str]] = _grid_out_property(
"aliases", "DEPRECATED, will be removed in PyMongo 5.0. List of aliases for this file."
)
metadata: Optional[Mapping[str, Any]] = _grid_out_property(
"metadata", "Metadata attached to this file."
)
md5: Optional[str] = _grid_out_property(
"md5",
"DEPRECATED, will be removed in PyMongo 5.0. MD5 of the contents of this file if an md5 sum was created.",
)
_file: Any
__chunk_iter: Any
def _ensure_file(self) -> None:
if not self._file:
_disallow_transactions(self._session)
self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session)
if not self._file:
raise NoFile(
f"no file in gridfs collection {self.__files!r} with _id {self.__file_id!r}"
)
def __getattr__(self, name: str) -> Any:
self._ensure_file()
if name in self._file:
return self._file[name]
raise AttributeError("GridOut object has no attribute '%s'" % name)
def readable(self) -> bool:
return True
def readchunk(self) -> bytes:
"""Reads a chunk at a time. If the current position is within a
chunk the remainder of the chunk is returned.
"""
received = len(self.__buffer) - self.__buffer_pos
chunk_data = EMPTY
chunk_size = int(self.chunk_size)
if received > 0:
chunk_data = self.__buffer[self.__buffer_pos :]
elif self.__position < int(self.length):
chunk_number = int((received + self.__position) / chunk_size)
if self.__chunk_iter is None:
self.__chunk_iter = _GridOutChunkIterator(
self, self.__chunks, self._session, chunk_number
)
chunk = self.__chunk_iter.next()
chunk_data = chunk["data"][self.__position % chunk_size :]
if not chunk_data:
raise CorruptGridFile("truncated chunk")
self.__position += len(chunk_data)
self.__buffer = EMPTY
self.__buffer_pos = 0
return chunk_data
def _read_size_or_line(self, size: int = -1, line: bool = False) -> bytes:
"""Internal read() and readline() helper."""
self._ensure_file()
remainder = int(self.length) - self.__position
if size < 0 or size > remainder:
size = remainder
if size == 0:
return EMPTY
received = 0
data = []
while received < size:
needed = size - received
if self.__buffer:
# Optimization: Read the buffer with zero byte copies.
buf = self.__buffer
chunk_start = self.__buffer_pos
chunk_data = memoryview(buf)[self.__buffer_pos :]
self.__buffer = EMPTY
self.__buffer_pos = 0
self.__position += len(chunk_data)
else:
buf = self.readchunk()
chunk_start = 0
chunk_data = memoryview(buf)
if line:
pos = buf.find(NEWLN, chunk_start, chunk_start + needed) - chunk_start
if pos >= 0:
# Decrease size to exit the loop.
size = received + pos + 1
needed = pos + 1
if len(chunk_data) > needed:
data.append(chunk_data[:needed])
# Optimization: Save the buffer with zero byte copies.
self.__buffer = buf
self.__buffer_pos = chunk_start + needed
self.__position -= len(self.__buffer) - self.__buffer_pos
else:
data.append(chunk_data)
received += len(chunk_data)
# Detect extra chunks after reading the entire file.
if size == remainder and self.__chunk_iter:
try:
self.__chunk_iter.next()
except StopIteration:
pass
return b"".join(data)
def read(self, size: int = -1) -> bytes:
"""Read at most `size` bytes from the file (less if there
isn't enough data).
The bytes are returned as an instance of :class:`bytes`
If `size` is negative or omitted all data is read.
:param size: the number of bytes to read
.. versionchanged:: 3.8
This method now only checks for extra chunks after reading the
entire file. Previously, this method would check for extra chunks
on every call.
"""
return self._read_size_or_line(size=size)
def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)
def tell(self) -> int:
"""Return the current position of this file."""
return self.__position
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.
:param pos: the position (or offset if using relative
positioning) to seek to
:param whence: where to seek
from. :attr:`os.SEEK_SET` (``0``) for absolute file
positioning, :attr:`os.SEEK_CUR` (``1``) to seek relative
to the current position, :attr:`os.SEEK_END` (``2``) to
seek relative to the file's end.
.. versionchanged:: 4.1
The method now returns the new position in the file, to
conform to the behavior of :meth:`io.IOBase.seek`.
"""
if whence == _SEEK_SET:
new_pos = pos
elif whence == _SEEK_CUR:
new_pos = self.__position + pos
elif whence == _SEEK_END:
new_pos = int(self.length) + pos
else:
raise OSError(22, "Invalid value for `whence`")
if new_pos < 0:
raise OSError(22, "Invalid value for `pos` - must be positive")
# Optimization, continue using the same buffer and chunk iterator.
if new_pos == self.__position:
return new_pos
self.__position = new_pos
self.__buffer = EMPTY
self.__buffer_pos = 0
if self.__chunk_iter:
self.__chunk_iter.close()
self.__chunk_iter = None
return new_pos
def seekable(self) -> bool:
return True
def __iter__(self) -> GridOut:
"""Return an iterator over all of this file's data.
The iterator will return lines (delimited by ``b'\\n'``) of
:class:`bytes`. This can be useful when serving files
using a webserver that handles such an iterator efficiently.
.. versionchanged:: 3.8
The iterator now raises :class:`CorruptGridFile` when encountering
any truncated, missing, or extra chunk in a file. The previous
behavior was to only raise :class:`CorruptGridFile` on a missing
chunk.
.. versionchanged:: 4.0
The iterator now iterates over *lines* in the file, instead
of chunks, to conform to the base class :py:class:`io.IOBase`.
Use :meth:`GridOut.readchunk` to read chunk by chunk instead
of line by line.
"""
return self
def close(self) -> None:
"""Make GridOut more generically file-like."""
if self.__chunk_iter:
self.__chunk_iter.close()
self.__chunk_iter = None
super().close()
def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
def writelines(self, lines: Any) -> NoReturn:
raise io.UnsupportedOperation("writelines")
def writable(self) -> bool:
return False
def __enter__(self) -> GridOut:
"""Makes it possible to use :class:`GridOut` files
with the context manager protocol.
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
"""Makes it possible to use :class:`GridOut` files
with the context manager protocol.
"""
self.close()
return False
def fileno(self) -> NoReturn:
raise io.UnsupportedOperation("fileno")
def flush(self) -> None:
# GridOut is read-only, so flush does nothing.
pass
def isatty(self) -> bool:
return False
def truncate(self, size: Optional[int] = None) -> NoReturn:
# See https://docs.python.org/3/library/io.html#io.IOBase.writable
# for why truncate has to raise.
raise io.UnsupportedOperation("truncate")
# Override IOBase.__del__ otherwise it will lead to __getattr__ on
# __IOBase_closed which calls _ensure_file and potentially performs I/O.
# We cannot do I/O in __del__ since it can lead to a deadlock.
def __del__(self) -> None:
pass
class _GridOutChunkIterator:
"""Iterates over a file's chunks using a single cursor.
Raises CorruptGridFile when encountering any truncated, missing, or extra
chunk in a file.
"""
def __init__(
self,
grid_out: GridOut,
chunks: Collection,
session: Optional[ClientSession],
next_chunk: Any,
) -> None:
self._id = grid_out._id
self._chunk_size = int(grid_out.chunk_size)
self._length = int(grid_out.length)
self._chunks = chunks
self._session = session
self._next_chunk = next_chunk
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
self._cursor = None
_cursor: Optional[Cursor]
def expected_chunk_length(self, chunk_n: int) -> int:
if chunk_n < self._num_chunks - 1:
return self._chunk_size
return self._length - (self._chunk_size * (self._num_chunks - 1))
def __iter__(self) -> _GridOutChunkIterator:
return self
def _create_cursor(self) -> None:
filter = {"files_id": self._id}
if self._next_chunk > 0:
filter["n"] = {"$gte": self._next_chunk}
_disallow_transactions(self._session)
self._cursor = self._chunks.find(filter, sort=[("n", 1)], session=self._session)
def _next_with_retry(self) -> Mapping[str, Any]:
"""Return the next chunk and retry once on CursorNotFound.
We retry on CursorNotFound to maintain backwards compatibility in
cases where two calls to read occur more than 10 minutes apart (the
server's default cursor timeout).
"""
if self._cursor is None:
self._create_cursor()
assert self._cursor is not None
try:
return self._cursor.next()
except CursorNotFound:
self._cursor.close()
self._create_cursor()
return self._cursor.next()
def next(self) -> Mapping[str, Any]:
try:
chunk = self._next_with_retry()
except StopIteration:
if self._next_chunk >= self._num_chunks:
raise
raise CorruptGridFile("no chunk #%d" % self._next_chunk) from None
if chunk["n"] != self._next_chunk:
self.close()
raise CorruptGridFile(
"Missing chunk: expected chunk #%d but found "
"chunk with n=%d" % (self._next_chunk, chunk["n"])
)
if chunk["n"] >= self._num_chunks:
# According to spec, ignore extra chunks if they are empty.
if len(chunk["data"]):
self.close()
raise CorruptGridFile(
"Extra chunk found: expected %d chunks but found "
"chunk with n=%d" % (self._num_chunks, chunk["n"])
)
expected_length = self.expected_chunk_length(chunk["n"])
if len(chunk["data"]) != expected_length:
self.close()
raise CorruptGridFile(
"truncated chunk #%d: expected chunk length to be %d but "
"found chunk with length %d" % (chunk["n"], expected_length, len(chunk["data"]))
)
self._next_chunk += 1
return chunk
__next__ = next
def close(self) -> None:
if self._cursor:
self._cursor.close()
self._cursor = None
class GridOutIterator:
def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0)
def __iter__(self) -> GridOutIterator:
return self
def next(self) -> bytes:
chunk = self.__chunk_iter.next()
return bytes(chunk["data"])
__next__ = next
class GridOutCursor(Cursor):
"""A cursor / iterator for returning GridOut objects as the result
of an arbitrary query against the GridFS files collection.
"""
def __init__(
self,
collection: Collection,
filter: Optional[Mapping[str, Any]] = None,
skip: int = 0,
limit: int = 0,
no_cursor_timeout: bool = False,
sort: Optional[Any] = None,
batch_size: int = 0,
session: Optional[ClientSession] = None,
) -> None:
"""Create a new cursor, similar to the normal
:class:`~pymongo.cursor.Cursor`.
Should not be called directly by application developers - see
the :class:`~gridfs.GridFS` method :meth:`~gridfs.GridFS.find` instead.
.. versionadded 2.7
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
_disallow_transactions(session)
collection = _clear_entity_type_registry(collection)
# Hold on to the base "fs" collection to create GridOut objects later.
self.__root_collection = collection
super().__init__(
collection.files,
filter,
skip=skip,
limit=limit,
no_cursor_timeout=no_cursor_timeout,
sort=sort,
batch_size=batch_size,
session=session,
)
def next(self) -> GridOut:
"""Get next GridOut object from cursor."""
_disallow_transactions(self.session)
next_file = super().next()
return GridOut(self.__root_collection, file_document=next_file, session=self.session)
__next__ = next
def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("Method does not exist for GridOutCursor")
def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("Method does not exist for GridOutCursor")
def _clone_base(self, session: Optional[ClientSession]) -> GridOutCursor:
"""Creates an empty GridOutCursor for information to be copied into."""
return GridOutCursor(self.__root_collection, session=session)
from gridfs.synchronous.grid_file import * # noqa: F403

149
gridfs/grid_file_shared.py Normal file
View File

@ -0,0 +1,149 @@
from __future__ import annotations
import os
import warnings
from typing import Any, Optional
from pymongo import ASCENDING
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
from pymongo.errors import InvalidOperation
_SEEK_SET = os.SEEK_SET
_SEEK_CUR = os.SEEK_CUR
_SEEK_END = os.SEEK_END
EMPTY = b""
NEWLN = b"\n"
"""Default chunk size, in bytes."""
# Slightly under a power of 2, to work well with server's record allocations.
DEFAULT_CHUNK_SIZE = 255 * 1024
# The number of chunked bytes to buffer before calling insert_many.
_UPLOAD_BUFFER_SIZE = MAX_MESSAGE_SIZE
# The number of chunk documents to buffer before calling insert_many.
_UPLOAD_BUFFER_CHUNKS = 100000
# Rough BSON overhead of a chunk document not including the chunk data itself.
# Essentially len(encode({"_id": ObjectId(), "files_id": ObjectId(), "n": 1, "data": ""}))
_CHUNK_OVERHEAD = 60
_C_INDEX: dict[str, Any] = {"files_id": ASCENDING, "n": ASCENDING}
_F_INDEX: dict[str, Any] = {"filename": ASCENDING, "uploadDate": ASCENDING}
def _a_grid_in_property(
field_name: str,
docstring: str,
read_only: Optional[bool] = False,
closed_only: Optional[bool] = False,
) -> Any:
"""Create a GridIn property."""
def getter(self: Any) -> Any:
if closed_only and not self._closed:
raise AttributeError("can only get %r on a closed file" % field_name)
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
if read_only:
docstring += "\n\nThis attribute is read-only."
elif closed_only:
docstring = "{}\n\n{}".format(
docstring,
"This attribute is read-only and "
"can only be read after :meth:`close` "
"has been called.",
)
return property(getter, doc=docstring)
def _a_grid_out_property(field_name: str, docstring: str) -> Any:
"""Create a GridOut property."""
def a_getter(self: Any) -> Any:
if not self._file:
raise InvalidOperation(
"You must call GridOut.open() before accessing " "the %s property" % field_name
)
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
docstring += "\n\nThis attribute is read-only."
return property(a_getter, doc=docstring)
def _grid_in_property(
field_name: str,
docstring: str,
read_only: Optional[bool] = False,
closed_only: Optional[bool] = False,
) -> Any:
"""Create a GridIn property."""
warn_str = ""
if docstring.startswith("DEPRECATED,"):
warn_str = (
f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
)
def getter(self: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if closed_only and not self._closed:
raise AttributeError("can only get %r on a closed file" % field_name)
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
def setter(self: Any, value: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}})
self._file[field_name] = value
if read_only:
docstring += "\n\nThis attribute is read-only."
elif closed_only:
docstring = "{}\n\n{}".format(
docstring,
"This attribute is read-only and "
"can only be read after :meth:`close` "
"has been called.",
)
if not read_only and not closed_only:
return property(getter, setter, doc=docstring)
return property(getter, doc=docstring)
def _grid_out_property(field_name: str, docstring: str) -> Any:
"""Create a GridOut property."""
warn_str = ""
if docstring.startswith("DEPRECATED,"):
warn_str = (
f"GridOut property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
)
def getter(self: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
self.open()
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
docstring += "\n\nThis attribute is read-only."
return property(getter, doc=docstring)
def _clear_entity_type_registry(entity: Any, **kwargs: Any) -> Any:
"""Clear the given database/collection object's type registry."""
codecopts = entity.codec_options.with_options(type_registry=None)
return entity.with_options(codec_options=codecopts, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -6,3 +6,10 @@ exclude = (?x)(
^test/mypy_fails/*.*$
| ^test/conftest.py$
)
[mypy-pymongo.synchronous.*,gridfs.synchronous.*,test.synchronous.*]
warn_unused_ignores = false
disable_error_code = unused-coroutine
[mypy-pymongo.asynchronous.*,test.asynchronous.*]
warn_unused_ignores = false

View File

@ -33,6 +33,7 @@ __all__ = [
"MIN_SUPPORTED_WIRE_VERSION",
"CursorType",
"MongoClient",
"AsyncMongoClient",
"DeleteMany",
"DeleteOne",
"IndexModel",
@ -87,11 +88,12 @@ TEXT = "text"
from pymongo import _csot
from pymongo._version import __version__, get_version_string, version_tuple
from pymongo.collection import ReturnDocument
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.cursor import CursorType
from pymongo.mongo_client import MongoClient
from pymongo.operations import (
from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.operations import (
DeleteMany,
DeleteOne,
IndexModel,
@ -100,7 +102,7 @@ from pymongo.operations import (
UpdateMany,
UpdateOne,
)
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
version = __version__

View File

@ -17,6 +17,7 @@
from __future__ import annotations
import functools
import inspect
import time
from collections import deque
from contextlib import AbstractContextManager
@ -96,16 +97,27 @@ F = TypeVar("F", bound=Callable[..., Any])
def apply(func: F) -> F:
"""Apply the client's timeoutMS to this operation."""
"""Apply the client's timeoutMS to this operation. Can wrap both asynchronous and synchronous methods"""
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
if get_timeout() is None:
timeout = self._timeout
if timeout is not None:
with _TimeoutContext(timeout):
return func(self, *args, **kwargs)
return func(self, *args, **kwargs)
@functools.wraps(func)
async def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
if get_timeout() is None:
timeout = self._timeout
if timeout is not None:
with _TimeoutContext(timeout):
return await func(self, *args, **kwargs)
return await func(self, *args, **kwargs)
else:
@functools.wraps(func)
def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
if get_timeout() is None:
timeout = self._timeout
if timeout is not None:
with _TimeoutContext(timeout):
return func(self, *args, **kwargs)
return func(self, *args, **kwargs)
return cast(F, csot_wrapper)

View File

View File

@ -0,0 +1,257 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Perform aggregation operations on a collection or database."""
from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo.asynchronous import common
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.server import Server
from pymongo.asynchronous.typings import _DocumentType, _Pipeline
_IS_SYNC = False
class _AggregationCommand:
"""The internal abstract base class for aggregation cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.collection.AsyncCollection.aggregate`, or
:meth:`pymongo.database.AsyncDatabase.aggregate` instead.
"""
def __init__(
self,
target: Union[AsyncDatabase, AsyncCollection],
cursor_class: type[AsyncCommandCursor],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
comment: Any = None,
) -> None:
if "explain" in options:
raise ConfigurationError(
"The explain option is not supported. Use AsyncDatabase.command instead."
)
self._target = target
pipeline = common.validate_list("pipeline", pipeline)
self._pipeline = pipeline
self._performs_write = False
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
self._performs_write = True
common.validate_is_mapping("options", options)
if let is not None:
common.validate_is_mapping("let", let)
options["let"] = let
if comment is not None:
options["comment"] = comment
self._options = options
# This is the batchSize that will be used for setting the initial
# batchSize for the cursor, as well as the subsequent getMores.
self._batch_size = common.validate_non_negative_integer_or_none(
"batchSize", self._options.pop("batchSize", None)
)
# If the cursor option is already specified, avoid overriding it.
self._options.setdefault("cursor", {})
# If the pipeline performs a write, we ignore the initial batchSize
# since the server doesn't return results in this case.
if self._batch_size is not None and not self._performs_write:
self._options["cursor"]["batchSize"] = self._batch_size
self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor
self._collation = validate_collation_or_none(options.pop("collation", None))
self._max_await_time_ms = options.pop("maxAwaitTimeMS", None)
self._write_preference: Optional[_AggWritePref] = None
@property
def _aggregation_target(self) -> Union[str, int]:
"""The argument to pass to the aggregate command."""
raise NotImplementedError
@property
def _cursor_namespace(self) -> str:
"""The namespace in which the aggregate command is run."""
raise NotImplementedError
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
"""The AsyncCollection used for the aggregate command cursor."""
raise NotImplementedError
@property
def _database(self) -> AsyncDatabase:
"""The database against which the aggregation command is run."""
raise NotImplementedError
def get_read_preference(
self, session: Optional[ClientSession]
) -> Union[_AggWritePref, _ServerMode]:
if self._write_preference:
return self._write_preference
pref = self._target._read_preference_for(session)
if self._performs_write and pref != ReadPreference.PRIMARY:
self._write_preference = pref = _AggWritePref(pref) # type: ignore[assignment]
return pref
async def get_cursor(
self,
session: Optional[ClientSession],
server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
# Serialize command.
cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline}
cmd.update(self._options)
# Apply this target's read concern if:
# readConcern has not been specified as a kwarg and either
# - server version is >= 4.2 or
# - server version is >= 3.2 and pipeline doesn't use $out
if ("readConcern" not in cmd) and (
not self._performs_write or (conn.max_wire_version >= 8)
):
read_concern = self._target.read_concern
else:
read_concern = None
# Apply this target's write concern if:
# writeConcern has not been specified as a kwarg and pipeline doesn't
# perform a write operation
if "writeConcern" not in cmd and self._performs_write:
write_concern = self._target._write_concern_for(session)
else:
write_concern = None
# Run command.
result = await conn.command(
self._database.name,
cmd,
read_preference,
self._target.codec_options,
parse_write_concern_error=True,
read_concern=read_concern,
write_concern=write_concern,
collation=self._collation,
session=session,
client=self._database.client,
user_fields=self._user_fields,
)
if self._result_processor:
self._result_processor(result, conn)
# Extract cursor from result or mock/fake one if necessary.
if "cursor" in result:
cursor = result["cursor"]
else:
# Unacknowledged $out/$merge write. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result.get("result", []),
"ns": self._cursor_namespace,
}
# Create and return cursor instance.
cmd_cursor = self._cursor_class(
self._cursor_collection(cursor),
cursor,
conn.address,
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
class _CollectionAggregationCommand(_AggregationCommand):
_target: AsyncCollection
@property
def _aggregation_target(self) -> str:
return self._target.name
@property
def _cursor_namespace(self) -> str:
return self._target.full_name
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
"""The AsyncCollection used for the aggregate command cursor."""
return self._target
@property
def _database(self) -> AsyncDatabase:
return self._target.database
class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# For raw-batches, we set the initial batchSize for the cursor to 0.
if not self._performs_write:
self._options["cursor"]["batchSize"] = 0
class _DatabaseAggregationCommand(_AggregationCommand):
_target: AsyncDatabase
@property
def _aggregation_target(self) -> int:
return 1
@property
def _cursor_namespace(self) -> str:
return f"{self._target.name}.$cmd.aggregate"
@property
def _database(self) -> AsyncDatabase:
return self._target
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
"""The AsyncCollection used for the aggregate command cursor."""
# AsyncCollection level aggregate may not always return the "ns" field
# according to our MockupDB tests. Let's handle that case for db level
# aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
_, collname = cursor.get("ns", self._cursor_namespace).split(".", 1)
return self._database[collname]

View File

@ -0,0 +1,663 @@
# Copyright 2013-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Authentication helpers."""
from __future__ import annotations
import functools
import hashlib
import hmac
import os
import socket
import typing
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Mapping,
MutableMapping,
Optional,
cast,
)
from urllib.parse import quote
from bson.binary import Binary
from pymongo.asynchronous.auth_aws import _authenticate_aws
from pymongo.asynchronous.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
if TYPE_CHECKING:
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.pool import Connection
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
import winkerberos as kerberos # type:ignore[import]
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
_USE_PRINCIPAL = True
except ImportError:
try:
import kerberos # type:ignore[import]
except ImportError:
HAVE_KERBEROS = False
_IS_SYNC = False
MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DEFAULT",
]
)
"""The authentication mechanisms supported by PyMongo."""
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
def __init__(self) -> None:
self.data = None
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self) -> int:
return self._hash_val
MongoCredential = namedtuple(
"MongoCredential",
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
)
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple(
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
)
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
properties = extra.get("authmechanismproperties", {})
service_name = properties.get("SERVICE_NAME", "mongodb")
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
service_realm = properties.get("SERVICE_REALM")
props = GSSAPIProperties(
service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm,
)
# Source is always $external.
return MongoCredential(mech, "$external", user, passwd, props, None)
elif mech == "MONGODB-X509":
if passwd is not None:
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for MONGODB-X509")
# Source is always $external, user can be None.
return MongoCredential(mech, "$external", user, None, None, None)
elif mech == "MONGODB-AWS":
if user is not None and passwd is None:
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
if source is not None and source != "$external":
raise ConfigurationError(
"authentication source must be $external or None for MONGODB-AWS"
)
properties = extra.get("authmechanismproperties", {})
aws_session_token = properties.get("AWS_SESSION_TOKEN")
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
callback = properties.get("OIDC_CALLBACK")
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
environ = properties.get("ENVIRONMENT")
token_resource = properties.get("TOKEN_RESOURCE", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
if passwd is not None:
msg = "password is not supported by MONGODB-OIDC"
raise ConfigurationError(msg)
if callback or human_callback:
if environ is not None:
raise ConfigurationError(msg)
if callback and human_callback:
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
raise ConfigurationError(msg)
elif environ is not None:
if environ == "test":
if user is not None:
msg = "test environment for MONGODB-OIDC does not support username"
raise ConfigurationError(msg)
callback = _OIDCTestCallback()
elif environ == "azure":
passwd = None
if not token_resource:
raise ConfigurationError(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:
raise ConfigurationError(msg)
oidc_props = _OIDCProperties(
callback=callback,
human_callback=human_callback,
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
"""Split a scram response into key, value pairs."""
return dict(
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
for item in response.split(b",")
)
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = {
"saslStart": 1,
"mechanism": mechanism,
"payload": Binary(b"n,," + first_bare),
"autoAuthorize": 1,
"options": {"skipEmptyExchange": True},
}
return nonce, first_bare, cmd
async def _authenticate_scram(
credentials: MongoCredential, conn: Connection, mechanism: str
) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == "SCRAM-SHA-256":
digest = "sha256"
digestmod = hashlib.sha256
data = saslprep(credentials.password).encode("utf-8")
else:
digest = "sha1"
digestmod = hashlib.sha1
data = _password_digest(username, credentials.password).encode("utf-8")
source = credentials.source
cache = credentials.cache
# Make local
_hmac = hmac.HMAC
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
assert isinstance(ctx, _ScramContext)
assert ctx.scram_data is not None
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = await conn.command(source, cmd)
assert res is not None
server_first = res["payload"]
parsed = _parse_scram_response(server_first)
iterations = int(parsed[b"i"])
if iterations < 4096:
raise OperationFailure("Server returned an invalid iteration count.")
salt = parsed[b"s"]
rnonce = parsed[b"r"]
if not rnonce.startswith(nonce):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
if cache.data:
client_key, server_key, csalt, citerations = cache.data
else:
client_key, server_key, csalt, citerations = None, None, None, None
# Salt and / or iterations could change for a number of different
# reasons. Either changing invalidates the cache.
if not client_key or salt != csalt or iterations != citerations:
salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key, salt, iterations)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
client_final = b",".join((without_proof, client_proof))
server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(client_final),
}
res = await conn.command(source, cmd)
parsed = _parse_scram_response(res["payload"])
if not hmac.compare_digest(parsed[b"v"], server_sig):
raise OperationFailure("Server returned an invalid signature.")
# A third empty challenge may be required if the server does not support
# skipEmptyExchange: SERVER-44857.
if not res["done"]:
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(b""),
}
res = await conn.command(source, cmd)
if not res["done"]:
raise OperationFailure("SASL conversation failed to complete.")
def _password_digest(username: str, password: str) -> str:
"""Get a password digest to use for authentication."""
if not isinstance(password, str):
raise TypeError("password must be an instance of str")
if len(password) == 0:
raise ValueError("password can't be empty")
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _auth_key(nonce: str, username: str, password: str) -> str:
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5() # noqa: S324
data = f"{nonce}{username}{digest}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _canonicalize_hostname(hostname: str) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
try:
name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
except socket.gaierror:
return canonname.lower()
return name[0].lower()
async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
'The "kerberos" module must be installed to use GSSAPI authentication.'
)
try:
username = credentials.username
password = credentials.password
props = credentials.mechanism_properties
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = conn.address[0]
if props.canonicalize_host_name:
host = _canonicalize_hostname(host)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
if password is not None:
if _USE_PRINCIPAL:
# Note that, though we use unquote_plus for unquoting URI
# options, we use quote here. Microsoft's UrlUnescape (used
# by WinKerberos) doesn't support +.
principal = ":".join((quote(username), quote(password)))
result, ctx = kerberos.authGSSClientInit(
service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
)
else:
if "@" in username:
user, domain = username.split("@", 1)
else:
user, domain = username, None
result, ctx = kerberos.authGSSClientInit(
service,
gssflags=kerberos.GSS_C_MUTUAL_FLAG,
user=user,
domain=domain,
password=password,
)
else:
result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
if result != kerberos.AUTH_GSS_COMPLETE:
raise OperationFailure("Kerberos context failed to initialize.")
try:
# pykerberos uses a weird mix of exceptions and return values
# to indicate errors.
# 0 == continue, 1 == complete, -1 == error
# Only authGSSClientStep can return 0.
if kerberos.authGSSClientStep(ctx, "") != 0:
raise OperationFailure("Unknown kerberos failure in step function.")
# Start a SASL conversation with mongod/s
# Note: pykerberos deals with base64 encoded byte strings.
# Since mongo accepts base64 strings as the payload we don't
# have to use bson.binary.Binary.
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslStart": 1,
"mechanism": "GSSAPI",
"payload": payload,
"autoAuthorize": 1,
}
response = await conn.command("$external", cmd)
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
if result == -1:
raise OperationFailure("Unknown kerberos failure in step function.")
payload = kerberos.authGSSClientResponse(ctx) or ""
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
response = await conn.command("$external", cmd)
if result == kerberos.AUTH_GSS_COMPLETE:
break
else:
raise OperationFailure("Kerberos authentication failed to complete.")
# Once the security context is established actually authenticate.
# See RFC 4752, Section 3.1, last two paragraphs.
if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")
if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
await conn.command("$external", cmd)
finally:
kerberos.authGSSClientClean(ctx)
except kerberos.KrbError as exc:
raise OperationFailure(str(exc)) from None
async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
password = credentials.password
payload = (f"\x00{username}\x00{password}").encode()
cmd = {
"saslStart": 1,
"mechanism": "PLAIN",
"payload": Binary(payload),
"autoAuthorize": 1,
}
await conn.command(source, cmd)
async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-X509."""
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step.
return
cmd = _X509Context(credentials, conn.address).speculate_command()
await conn.command("$external", cmd)
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
password = credentials.password
# Get a nonce
response = await conn.command(source, {"getnonce": 1})
nonce = response["nonce"]
key = _auth_key(nonce, username, password)
# Actually authenticate
query = {"authenticate": 1, "user": username, "nonce": nonce, "key": key}
await conn.command(source, query)
async def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
if conn.max_wire_version >= 7:
if conn.negotiated_mechs:
mechs = conn.negotiated_mechs
else:
source = credentials.source
cmd = conn.hello_cmd()
cmd["saslSupportedMechs"] = source + "." + credentials.username
mechs = (await conn.command(source, cmd, publish_events=False)).get(
"saslSupportedMechs", []
)
if "SCRAM-SHA-256" in mechs:
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
else:
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
else:
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
_AUTH_MAP: Mapping[str, Callable[..., Coroutine[Any, Any, None]]] = {
"GSSAPI": _authenticate_gssapi,
"MONGODB-CR": _authenticate_mongo_cr,
"MONGODB-X509": _authenticate_x509,
"MONGODB-AWS": _authenticate_aws,
"MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item]
"PLAIN": _authenticate_plain,
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
"DEFAULT": _authenticate_default,
}
class _AuthContext:
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
self.credentials = credentials
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
self.address = address
@staticmethod
def from_credentials(
creds: MongoCredential, address: tuple[str, int]
) -> Optional[_AuthContext]:
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
if spec_cls:
return cast(_AuthContext, spec_cls(creds, address))
return None
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
raise NotImplementedError
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
self.speculative_authenticate = hello.speculative_authenticate
def speculate_succeeded(self) -> bool:
return bool(self.speculative_authenticate)
class _ScramContext(_AuthContext):
def __init__(
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
) -> None:
super().__init__(credentials, address)
self.scram_data: Optional[tuple[bytes, bytes]] = None
self.mechanism = mechanism
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
# The 'db' field is included only on the speculative command.
cmd["db"] = self.credentials.source
# Save for later use.
self.scram_data = (nonce, first_bare)
return cmd
class _X509Context(_AuthContext):
def speculate_command(self) -> MutableMapping[str, Any]:
cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
if self.credentials.username is not None:
cmd["user"] = self.credentials.username
return cmd
class _OIDCContext(_AuthContext):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.get_spec_auth_cmd()
if cmd is None:
return None
cmd["db"] = self.credentials.source
return cmd
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
"MONGODB-X509": _X509Context,
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
"MONGODB-OIDC": _OIDCContext,
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}
async def authenticate(
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
"""Authenticate connection."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]
if mechanism == "MONGODB-OIDC":
await _authenticate_oidc(credentials, conn, reauthenticate)
else:
await auth_func(credentials, conn)

View File

@ -0,0 +1,100 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-AWS Authentication helpers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Type
import bson
from bson.binary import Binary
from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING:
from bson.typings import _ReadableBuffer
from pymongo.asynchronous.auth import MongoCredential
from pymongo.asynchronous.pool import Connection
_IS_SYNC = False
async def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-AWS."""
try:
import pymongo_auth_aws # type:ignore[import]
except ImportError as e:
raise ConfigurationError(
"MONGODB-AWS authentication requires pymongo-auth-aws: "
"install with: python -m pip install 'pymongo[aws]'"
) from e
# Delayed import.
from pymongo_auth_aws.auth import ( # type:ignore[import]
set_cached_credentials,
set_use_cached_credentials,
)
set_use_cached_credentials(True)
if conn.max_wire_version < 9:
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
class AwsSaslContext(pymongo_auth_aws.AwsSaslContext): # type: ignore
# Dependency injection:
def binary_type(self) -> Type[Binary]:
"""Return the bson.binary.Binary type."""
return Binary
def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
"""Encode a dictionary to BSON."""
return bson.encode(doc)
def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
"""Decode BSON to a dictionary."""
return bson.decode(data)
try:
ctx = AwsSaslContext(
pymongo_auth_aws.AwsCredential(
credentials.username,
credentials.password,
credentials.mechanism_properties.aws_session_token,
)
)
client_payload = ctx.step(None)
client_first = {"saslStart": 1, "mechanism": "MONGODB-AWS", "payload": client_payload}
server_first = await conn.command("$external", client_first)
res = server_first
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
client_payload = ctx.step(res["payload"])
cmd = {
"saslContinue": 1,
"conversationId": server_first["conversationId"],
"payload": client_payload,
}
res = await conn.command("$external", cmd)
if res["done"]:
# SASL complete.
break
except pymongo_auth_aws.PyMongoAuthAwsError as exc:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
# Convert to OperationFailure and include pymongo-auth-aws version.
raise OperationFailure(
f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
) from None
except Exception:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
raise

View File

@ -0,0 +1,380 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import abc
import os
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
from urllib.parse import quote
import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.asynchronous.auth import MongoCredential
from pymongo.asynchronous.pool import Connection
_IS_SYNC = False
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@dataclass
class OIDCCallbackContext:
timeout_seconds: float
username: str
version: int
refresh_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
@dataclass
class OIDCCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)
class OIDCCallback(abc.ABC):
"""A base class for defining OIDC callbacks."""
@abc.abstractmethod
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
"""Convert the given BSON value into our own type."""
@dataclass
class _OIDCProperties:
callback: Optional[OIDCCallback] = field(default=None)
human_callback: Optional[OIDCCallback] = field(default=None)
environment: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)
token_resource: Optional[str] = field(default=None)
username: str = ""
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
def _get_authenticator(
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
if credentials.cache.data:
return credentials.cache.data
# Extract values.
principal_name = credentials.username
properties = credentials.mechanism_properties
# Validate that the address is allowed.
if not properties.environment:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:
if patt == address[0]:
found = True
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
found = True
if not found:
raise ConfigurationError(
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
)
# Get or create the cache data.
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
return credentials.cache.data
class _OIDCTestCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("OIDC_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAWSCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAzureCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
return OIDCCallbackResult(
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])
@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
refresh_token: Optional[str] = field(default=None)
access_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)
last_call_time: float = field(default=0)
async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle a reauthenticate from the server."""
# Invalidate the token for the connection.
self._invalidate(conn)
# Call the appropriate auth logic for the callback type.
if self.properties.callback:
return await self._authenticate_machine(conn)
return await self._authenticate_human(conn)
async def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle an initial authenticate request."""
# First handle speculative auth.
# If it succeeded, we are done.
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
resp = ctx.speculative_authenticate
if resp and resp["done"]:
conn.oidc_token_gen_id = self.token_gen_id
return resp
# If spec auth failed, call the appropriate auth logic for the callback type.
# We cannot assume that the token is invalid, because a proxy may have been
# involved that stripped the speculative auth information.
if self.properties.callback:
return await self._authenticate_machine(conn)
return await self._authenticate_human(conn)
def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
"""Get the appropriate speculative auth command."""
if not self.access_token:
return None
return self._get_start_command({"jwt": self.access_token})
async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
# fails for any other reason, raise the error to the user.
if self.access_token:
try:
return await self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return await self._authenticate_machine(conn)
raise
return await self._sasl_start_jwt(conn)
async def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
# reason, raise the error to the user.
if self.access_token:
try:
return await self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return await self._authenticate_human(conn)
raise
# If we have a cached refresh token, try a JwtStepRequest with that.
# If authentication fails with error code 18, invalidate the access and
# refresh tokens, and try to authenticate again. If authentication fails for
# any other reason, raise the error to the user.
if self.refresh_token:
try:
return await self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
self.refresh_token = None
return await self._authenticate_human(conn)
raise
# Start a new Two-Step SASL conversation.
# Run a PrincipalStepRequest to get the IdpInfo.
cmd = self._get_start_command(None)
start_resp = await self._run_command(conn, cmd)
# Attempt to authenticate with a JwtStepRequest.
return await self._sasl_continue_jwt(conn, start_resp)
def _get_access_token(self) -> Optional[str]:
properties = self.properties
cb: Union[None, OIDCCallback]
resp: OIDCCallbackResult
is_human = properties.human_callback is not None
if is_human and self.idp_info is None:
return None
if properties.callback:
cb = properties.callback
if properties.human_callback:
cb = properties.human_callback
prev_token = self.access_token
if prev_token:
return prev_token
if cb is None and not prev_token:
return None
if not prev_token and cb is not None:
with self.lock:
# See if the token was changed while we were waiting for the
# lock.
new_token = self.access_token
if new_token != prev_token:
return new_token
# Ensure that we are waiting a min time between callback invocations.
delta = time.time() - self.last_call_time
if delta < TIME_BETWEEN_CALLS_SECONDS:
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
self.last_call_time = time.time()
if is_human:
timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
assert self.idp_info is not None
else:
timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
context = OIDCCallbackContext(
timeout_seconds=timeout,
version=CALLBACK_VERSION,
refresh_token=self.refresh_token,
idp_info=self.idp_info,
username=self.properties.username,
)
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError("Callback result must be of type OIDCCallbackResult")
self.refresh_token = resp.refresh_token
self.access_token = resp.access_token
self.token_gen_id += 1
return self.access_token
async def _run_command(
self, conn: Connection, cmd: MutableMapping[str, Any]
) -> Mapping[str, Any]:
try:
return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as e:
if self._is_auth_error(e):
self._invalidate(conn)
raise
def _is_auth_error(self, err: Exception) -> bool:
if not isinstance(err, OperationFailure):
return False
return err.code == _AUTHENTICATION_FAILURE_CODE
def _invalidate(self, conn: Connection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.
token_gen_id = conn.oidc_token_gen_id or 0
if token_gen_id is not None and token_gen_id < self.token_gen_id:
return
self.access_token = None
async def _sasl_continue_jwt(
self, conn: Connection, start_resp: Mapping[str, Any]
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
start_payload: dict = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
return await self._run_command(conn, cmd)
async def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_start_command({"jwt": access_token})
return await self._run_command(conn, cmd)
def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
if payload is None:
principal_name = self.username
if principal_name:
payload = {"n": principal_name}
else:
payload = {}
bin_payload = Binary(bson.encode(payload))
return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}
def _get_continue_command(
self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
) -> MutableMapping[str, Any]:
bin_payload = Binary(bson.encode(payload))
return {
"saslContinue": 1,
"payload": bin_payload,
"conversationId": start_resp["conversationId"],
}
async def _authenticate_oidc(
credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, conn.address)
if reauthenticate:
return await authenticator.reauthenticate(conn)
else:
return await authenticator.authenticate(conn)

View File

@ -0,0 +1,599 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The bulk write operations interface.
.. versionadded:: 2.7
"""
from __future__ import annotations
import copy
from collections.abc import MutableMapping
from itertools import islice
from typing import (
TYPE_CHECKING,
Any,
Iterator,
Mapping,
NoReturn,
Optional,
Type,
Union,
)
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot
from pymongo.asynchronous import common
from pymongo.asynchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.asynchronous.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo.asynchronous.helpers import _get_wce_doc
from pymongo.asynchronous.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_EncryptedBulkWriteContext,
_randint,
)
from pymongo.asynchronous.read_preferences import ReadPreference
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure,
)
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.typings import _DocumentOut, _DocumentType, _Pipeline
_IS_SYNC = False
_DELETE_ALL: int = 0
_DELETE_ONE: int = 1
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
_BAD_VALUE: int = 2
_UNKNOWN_ERROR: int = 8
_WRITE_CONCERN_ERROR: int = 64
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
class _Run:
"""Represents a batch of write operations."""
def __init__(self, op_type: int) -> None:
"""Initialize a new Run object."""
self.op_type: int = op_type
self.index_map: list[int] = []
self.ops: list[Any] = []
self.idx_offset: int = 0
def index(self, idx: int) -> int:
"""Get the original index of an operation in this run.
:param idx: The Run index that maps to the original index.
"""
return self.index_map[idx]
def add(self, original_index: int, operation: Any) -> None:
"""Add an operation to this Run instance.
:param original_index: The original index of this operation
within a larger bulk operation.
:param operation: The operation document.
"""
self.index_map.append(original_index)
self.ops.append(operation)
def _merge_command(
run: _Run,
full_result: MutableMapping[str, Any],
offset: int,
result: Mapping[str, Any],
) -> None:
"""Merge a write command result into the full bulk result."""
affected = result.get("n", 0)
if run.op_type == _INSERT:
full_result["nInserted"] += affected
elif run.op_type == _DELETE:
full_result["nRemoved"] += affected
elif run.op_type == _UPDATE:
upserted = result.get("upserted")
if upserted:
n_upserted = len(upserted)
for doc in upserted:
doc["index"] = run.index(doc["index"] + offset)
full_result["upserted"].extend(upserted)
full_result["nUpserted"] += n_upserted
full_result["nMatched"] += affected - n_upserted
else:
full_result["nMatched"] += affected
full_result["nModified"] += result["nModified"]
write_errors = result.get("writeErrors")
if write_errors:
for doc in write_errors:
# Leave the server response intact for APM.
replacement = doc.copy()
idx = doc["index"] + offset
replacement["index"] = run.index(idx)
# Add the failed operation to the error document.
replacement["op"] = run.ops[idx]
full_result["writeErrors"].append(replacement)
wce = _get_wce_doc(result)
if wce:
full_result["writeConcernErrors"].append(wce)
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
"""Raise a BulkWriteError from the full bulk api result."""
# retryWrites on MMAPv1 should raise an actionable error.
if full_result["writeErrors"]:
full_result["writeErrors"].sort(key=lambda error: error["index"])
err = full_result["writeErrors"][0]
code = err["code"]
msg = err["errmsg"]
if code == 20 and msg.startswith("Transaction numbers"):
errmsg = (
"This MongoDB deployment does not support "
"retryable writes. Please add retryWrites=false "
"to your connection string."
)
raise OperationFailure(errmsg, code, full_result)
raise BulkWriteError(full_result)
class _Bulk:
"""The private guts of the bulk write API."""
def __init__(
self,
collection: AsyncCollection[_DocumentType],
ordered: bool,
bypass_document_validation: bool,
comment: Optional[str] = None,
let: Optional[Any] = None,
) -> None:
"""Initialize a _Bulk instance."""
self.collection = collection.with_options(
codec_options=collection.codec_options._replace(
unicode_decode_error_handler="replace", document_class=dict
)
)
self.let = let
if self.let is not None:
common.validate_is_document_type("let", self.let)
self.comment: Optional[str] = comment
self.ordered = ordered
self.ops: list[tuple[int, Mapping[str, Any]]] = []
self.executed = False
self.bypass_doc_val = bypass_document_validation
self.uses_collation = False
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.is_retryable = True
self.retrying = False
self.started_retryable_write = False
# Extra state so that we know where to pick up on a retry attempt.
self.current_run = None
self.next_run = None
@property
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
encrypter = self.collection.database.client._encrypter
if encrypter and not encrypter._bypass_auto_encryption:
return _EncryptedBulkWriteContext
else:
return _BulkWriteContext
def add_insert(self, document: _DocumentOut) -> None:
"""Add an insert document to the list of ops."""
validate_is_document_type("document", document)
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or "_id" in document):
document["_id"] = ObjectId()
self.ops.append((_INSERT, document))
def add_update(
self,
selector: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
multi: bool = False,
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd: dict[str, Any] = dict( # noqa: C406
[("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)]
)
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if array_filters is not None:
self.uses_array_filters = True
cmd["arrayFilters"] = array_filters
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append((_UPDATE, cmd))
def add_replace(
self,
selector: Mapping[str, Any],
replacement: Mapping[str, Any],
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd = {"q": selector, "u": replacement, "multi": False, "upsert": upsert}
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
self.ops.append((_UPDATE, cmd))
def add_delete(
self,
selector: Mapping[str, Any],
limit: int,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a delete document and add it to the list of ops."""
cmd = {"q": selector, "limit": limit}
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint
if limit == _DELETE_ALL:
# A bulk_write containing a delete_many is not retryable.
self.is_retryable = False
self.ops.append((_DELETE, cmd))
def gen_ordered(self) -> Iterator[Optional[_Run]]:
"""Generate batches of operations, batched by type of
operation, in the order **provided**.
"""
run = None
for idx, (op_type, operation) in enumerate(self.ops):
if run is None:
run = _Run(op_type)
elif run.op_type != op_type:
yield run
run = _Run(op_type)
run.add(idx, operation)
yield run
def gen_unordered(self) -> Iterator[_Run]:
"""Generate batches of operations, batched by type of
operation, in arbitrary order.
"""
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
for idx, (op_type, operation) in enumerate(self.ops):
operations[op_type].add(idx, operation)
for run in operations:
if run.ops:
yield run
async def _execute_command(
self,
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[ClientSession],
conn: Connection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
final_write_concern: Optional[WriteConcern] = None,
) -> None:
db_name = self.collection.database.name
client = self.collection.database.client
listeners = client._event_listeners
if not self.current_run:
self.current_run = next(generator)
self.next_run = None
run = self.current_run
# Connection.command validates the session, but we use
# Connection.write_command
conn.validate_session(client, session)
last_run = False
while run:
if not self.retrying:
self.next_run = next(generator, None)
if self.next_run is None:
last_run = True
cmd_name = _COMMANDS[run.op_type]
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
conn,
op_id,
listeners,
session,
run.op_type,
self.collection.codec_options,
)
while run.idx_offset < len(run.ops):
# If this is the last possible operation, use the
# final write concern.
if last_run and (len(run.ops) - run.idx_offset) == 1:
write_concern = final_write_concern or write_concern
cmd = {cmd_name: self.collection.name, "ordered": self.ordered}
if self.comment:
cmd["comment"] = self.comment
_csot.apply_write_concern(cmd, write_concern)
if self.bypass_doc_val:
cmd["bypassDocumentValidation"] = True
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
cmd["let"] = self.let
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
await session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
conn.send_cluster_time(cmd, session, client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(client, cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible in one command.
if write_concern.acknowledged:
result, to_send = await bwc.execute(cmd, ops, client)
# Retryable writeConcernErrors halt the execution of this run.
wce = result.get("writeConcernError", {})
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
full = copy.deepcopy(full_result)
_merge_command(run, full, run.idx_offset, result)
_raise_bulk_write_error(full)
_merge_command(run, full_result, run.idx_offset, result)
# We're no longer in a retry once a command succeeds.
self.retrying = False
self.started_retryable_write = False
if self.ordered and "writeErrors" in result:
break
else:
to_send = await bwc.execute_unack(cmd, ops, client)
run.idx_offset += len(to_send)
# We're supposed to continue if errors are
# at the write concern level (e.g. wtimeout)
if self.ordered and full_result["writeErrors"]:
break
# Reset our state
self.current_run = run = self.next_run
async def execute_command(
self,
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[ClientSession],
operation: str,
) -> dict[str, Any]:
"""Execute using write commands."""
# nModified is only reported for write commands, not legacy ops.
full_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
op_id = _randint()
async def retryable_bulk(
session: Optional[ClientSession], conn: Connection, retryable: bool
) -> None:
await self._execute_command(
generator,
write_concern,
session,
conn,
op_id,
retryable,
full_result,
)
client = self.collection.database.client
_ = await client._retryable_write(
self.is_retryable,
retryable_bulk,
session,
operation,
bulk=self,
operation_id=op_id,
)
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
_raise_bulk_write_error(full_result)
return full_result
async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name
client = self.collection.database.client
listeners = client._event_listeners
op_id = _randint()
if not self.current_run:
self.current_run = next(generator)
run = self.current_run
while run:
cmd_name = _COMMANDS[run.op_type]
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
conn,
op_id,
listeners,
None,
run.op_type,
self.collection.codec_options,
)
while run.idx_offset < len(run.ops):
cmd = {
cmd_name: self.collection.name,
"ordered": False,
"writeConcern": {"w": 0},
}
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = await bwc.execute_unack(cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)
async def execute_command_no_results(
self,
conn: Connection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
"""Execute write commands with OP_MSG and w=0 WriteConcern, ordered."""
full_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
# Ordered bulk writes have to be acknowledged so that we stop
# processing at the first error, even when the application
# specified unacknowledged writeConcern.
initial_write_concern = WriteConcern()
op_id = _randint()
try:
await self._execute_command(
generator,
initial_write_concern,
None,
conn,
op_id,
False,
full_result,
write_concern,
)
except OperationFailure:
pass
async def execute_no_results(
self,
conn: Connection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
"""Execute all operations, returning no results (w=0)."""
if self.uses_collation:
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
if self.uses_array_filters:
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
# Guard against unsupported unacknowledged writes.
unack = write_concern and not write_concern.acknowledged
if unack and self.uses_hint_delete and conn.max_wire_version < 9:
raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
if unack and self.uses_hint_update and conn.max_wire_version < 8:
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
# Cannot have both unacknowledged writes and bypass document validation.
if self.bypass_doc_val:
raise OperationFailure(
"Cannot set bypass_document_validation with unacknowledged write concern"
)
if self.ordered:
return await self.execute_command_no_results(conn, generator, write_concern)
return await self.execute_op_msg_no_results(conn, generator)
async def execute(
self,
write_concern: WriteConcern,
session: Optional[ClientSession],
operation: str,
) -> Any:
"""Execute operations."""
if not self.ops:
raise InvalidOperation("No operations to execute")
if self.executed:
raise InvalidOperation("Bulk operations can only be executed once.")
self.executed = True
write_concern = write_concern or self.collection.write_concern
session = _validate_session_write_concern(session, write_concern)
if self.ordered:
generator = self.gen_ordered()
else:
generator = self.gen_unordered()
client = self.collection.database.client
if not write_concern.acknowledged:
async with await client._conn_for_writes(session, operation) as connection:
await self.execute_no_results(connection, generator, write_concern)
return None
else:
return await self.execute_command(generator, write_concern, session, operation)

View File

@ -0,0 +1,499 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Watch changes on a collection, a database, or the entire cluster."""
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.asynchronous import common
from pymongo.asynchronous.aggregation import (
_AggregationCommand,
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
)
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.operations import _Op
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.errors import (
ConnectionFailure,
CursorNotFound,
InvalidOperation,
OperationFailure,
PyMongoError,
)
_IS_SYNC = False
# The change streams spec considers the following server errors from the
# getMore command non-resumable. All other getMore errors are resumable.
_RESUMABLE_GETMORE_ERRORS = frozenset(
[
6, # HostUnreachable
7, # HostNotFound
89, # NetworkTimeout
91, # ShutdownInProgress
189, # PrimarySteppedDown
262, # ExceededTimeLimit
9001, # SocketException
10107, # NotWritablePrimary
11600, # InterruptedAtShutdown
11602, # InterruptedDueToReplStateChange
13435, # NotPrimaryNoSecondaryOk
13436, # NotPrimaryOrSecondary
63, # StaleShardVersion
150, # StaleEpoch
13388, # StaleConfig
234, # RetryChangeStream
133, # FailedToSatisfyReadPreference
]
)
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import Connection
def _resumable(exc: PyMongoError) -> bool:
"""Return True if given a resumable change stream error."""
if isinstance(exc, (ConnectionFailure, CursorNotFound)):
return True
if isinstance(exc, OperationFailure):
if exc._max_wire_version is None:
return False
return (
exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
return False
class ChangeStream(Generic[_DocumentType]):
"""The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.collection.AsyncCollection.watch`,
:meth:`pymongo.database.AsyncDatabase.watch`, or
:meth:`pymongo.mongo_client.AsyncMongoClient.watch` instead.
.. versionadded:: 3.6
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
"""
def __init__(
self,
target: Union[
AsyncMongoClient[_DocumentType],
AsyncDatabase[_DocumentType],
AsyncCollection[_DocumentType],
],
pipeline: Optional[_Pipeline],
full_document: Optional[str],
resume_after: Optional[Mapping[str, Any]],
max_await_time_ms: Optional[int],
batch_size: Optional[int],
collation: Optional[_CollationIn],
start_at_operation_time: Optional[Timestamp],
session: Optional[ClientSession],
start_after: Optional[Mapping[str, Any]],
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
show_expanded_events: Optional[bool] = None,
) -> None:
if pipeline is None:
pipeline = []
pipeline = common.validate_list("pipeline", pipeline)
common.validate_string_or_none("full_document", full_document)
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._decode_custom = False
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options( # type: ignore
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
)
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document
self._full_document_before_change = full_document_before_change
self._uses_start_after = start_after is not None
self._uses_resume_after = resume_after is not None
self._resume_token = copy.deepcopy(start_after or resume_after)
self._max_await_time_ms = max_await_time_ms
self._batch_size = batch_size
self._collation = collation
self._start_at_operation_time = start_at_operation_time
self._session = session
self._comment = comment
self._closed = False
self._timeout = self._target._timeout
self._show_expanded_events = show_expanded_events
async def _initialize_cursor(self) -> None:
# Initialize cursor.
self._cursor = await self._create_cursor()
@property
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
"""The aggregation command class to be used."""
raise NotImplementedError
@property
def _client(self) -> AsyncMongoClient:
"""The client against which the aggregation commands for
this ChangeStream will be run.
"""
raise NotImplementedError
def _change_stream_options(self) -> dict[str, Any]:
"""Return the options dict for the $changeStream pipeline stage."""
options: dict[str, Any] = {}
if self._full_document is not None:
options["fullDocument"] = self._full_document
if self._full_document_before_change is not None:
options["fullDocumentBeforeChange"] = self._full_document_before_change
resume_token = self.resume_token
if resume_token is not None:
if self._uses_start_after:
options["startAfter"] = resume_token
else:
options["resumeAfter"] = resume_token
elif self._start_at_operation_time is not None:
options["startAtOperationTime"] = self._start_at_operation_time
if self._show_expanded_events:
options["showExpandedEvents"] = self._show_expanded_events
return options
def _command_options(self) -> dict[str, Any]:
"""Return the options dict for the aggregation command."""
options = {}
if self._max_await_time_ms is not None:
options["maxAwaitTimeMS"] = self._max_await_time_ms
if self._batch_size is not None:
options["batchSize"] = self._batch_size
return options
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
"""Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents.
This is implemented as a callback because we need access to the wire
version in order to determine whether to cache this value.
"""
if not result["cursor"]["firstBatch"]:
if "postBatchResumeToken" in result["cursor"]:
self._resume_token = result["cursor"]["postBatchResumeToken"]
elif (
self._start_at_operation_time is None
and self._uses_resume_after is False
and self._uses_start_after is False
and conn.max_wire_version >= 7
):
self._start_at_operation_time = result.get("operationTime")
# PYTHON-2181: informative error on missing operationTime.
if self._start_at_operation_time is None:
raise OperationFailure(
"Expected field 'operationTime' missing from command "
f"response : {result!r}"
)
async def _run_aggregation_cmd(
self, session: Optional[ClientSession], explicit_session: bool
) -> AsyncCommandCursor:
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding AsyncCommandCursor.
"""
cmd = self._aggregation_command_class(
self._target,
AsyncCommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
return await self._client._retryable_read(
cmd.get_cursor,
self._target._read_preference_for(session),
session,
operation=_Op.AGGREGATE,
)
async def _create_cursor(self) -> AsyncCommandCursor:
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None
)
async def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""
try:
await self._cursor.close()
except PyMongoError:
pass
self._cursor = await self._create_cursor()
async def close(self) -> None:
"""Close this ChangeStream."""
self._closed = True
await self._cursor.close()
def __aiter__(self) -> ChangeStream[_DocumentType]:
return self
@property
def resume_token(self) -> Optional[Mapping[str, Any]]:
"""The cached resume token that will be used to resume after the most
recently returned change.
.. versionadded:: 3.9
"""
return copy.deepcopy(self._resume_token)
@_csot.apply
async def next(self) -> _DocumentType:
"""Advance the cursor.
This method blocks until the next change document is returned or an
unrecoverable error is raised. This method is used when iterating over
all changes in the cursor. For example::
try:
resume_token = None
pipeline = [{'$match': {'operationType': 'insert'}}]
async with db.collection.watch(pipeline) as stream:
async for insert_change in stream:
print(insert_change)
resume_token = stream.resume_token
except pymongo.errors.PyMongoError:
# The ChangeStream encountered an unrecoverable error or the
# resume attempt failed to recreate the cursor.
if resume_token is None:
# There is no usable resume token because there was a
# failure during ChangeStream initialization.
logging.error('...')
else:
# Use the interrupted ChangeStream's resume token to create
# a new ChangeStream. The new stream will continue from the
# last seen insert change without missing any events.
async with db.collection.watch(
pipeline, resume_after=resume_token) as stream:
async for insert_change in stream:
print(insert_change)
Raises :exc:`StopIteration` if this ChangeStream is closed.
"""
while self.alive:
doc = await self.try_next()
if doc is not None:
return doc
raise StopAsyncIteration
__anext__ = next
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
.. versionadded:: 3.8
"""
return not self._closed
@_csot.apply
async def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
indefinitely for the next change. For example::
async with db.collection.watch() as stream:
while stream.alive:
change = await stream.try_next()
# Note that the ChangeStream's resume token may be updated
# even when no changes are returned.
print("Current resume token: %r" % (stream.resume_token,))
if change is not None:
print("Change document: %r" % (change,))
continue
# We end up here when there are no recent changes.
# Sleep for a while before trying again to avoid flooding
# the server with getMore requests when no changes are
# available.
asyncio.sleep(10)
If no change document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there have been no changes) then ``None`` is returned.
:return: The next change document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 3.8
"""
if not self._closed and not self._cursor.alive:
await self._resume()
# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
try:
change = await self._cursor._try_next(True)
except PyMongoError as exc:
if not _resumable(exc):
raise
await self._resume()
change = await self._cursor._try_next(False)
except PyMongoError as exc:
# Close the stream after a fatal error.
if not _resumable(exc) and not exc.timeout:
await self.close()
raise
except Exception:
await self.close()
raise
# Check if the cursor was invalidated.
if not self._cursor.alive:
self._closed = True
# If no changes are available.
if change is None:
# We have either iterated over all documents in the cursor,
# OR the most-recently returned batch is empty. In either case,
# update the cached resume token with the postBatchResumeToken if
# one was returned. We also clear the startAtOperationTime.
if self._cursor._post_batch_resume_token is not None:
self._resume_token = self._cursor._post_batch_resume_token
self._start_at_operation_time = None
return change
# Else, changes are available.
try:
resume_token = change["_id"]
except KeyError:
await self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume token is missing."
) from None
# If this is the last change document from the current batch, cache the
# postBatchResumeToken.
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
resume_token = self._cursor._post_batch_resume_token
# Hereafter, don't use startAfter; instead use resumeAfter.
self._uses_start_after = False
self._uses_resume_after = True
# Cache the resume token and clear startAtOperationTime.
self._resume_token = resume_token
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
async def __aenter__(self) -> ChangeStream[_DocumentType]:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
class CollectionChangeStream(ChangeStream[_DocumentType]):
"""A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.collection.AsyncCollection.watch` instead.
.. versionadded:: 3.7
"""
_target: AsyncCollection[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
return _CollectionAggregationCommand
@property
def _client(self) -> AsyncMongoClient[_DocumentType]:
return self._target.database.client
class DatabaseChangeStream(ChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.database.AsyncDatabase.watch` instead.
.. versionadded:: 3.7
"""
_target: AsyncDatabase[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
return _DatabaseAggregationCommand
@property
def _client(self) -> AsyncMongoClient[_DocumentType]:
return self._target.client
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.mongo_client.AsyncMongoClient.watch` instead.
.. versionadded:: 3.7
"""
def _change_stream_options(self) -> dict[str, Any]:
options = super()._change_stream_options()
options["allChangesForCluster"] = True
return options

View File

@ -0,0 +1,334 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
from bson.codec_options import _parse_codec_options
from pymongo.asynchronous import common
from pymongo.asynchronous.compression_support import CompressionSettings
from pymongo.asynchronous.monitoring import _EventListener, _EventListeners
from pymongo.asynchronous.pool import PoolOptions
from pymongo.asynchronous.read_preferences import (
_ServerMode,
make_read_preference,
read_pref_mode_from_name,
)
from pymongo.asynchronous.server_selectors import any_server_selector
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern, validate_boolean
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.asynchronous.auth import MongoCredential
from pymongo.asynchronous.encryption_options import AutoEncryptionOpts
from pymongo.asynchronous.topology_description import _ServerSelector
from pymongo.pyopenssl_context import SSLContext
_IS_SYNC = False
def _parse_credentials(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
"""Parse authentication credentials."""
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
source = options.get("authsource")
if username or mechanism:
from pymongo.asynchronous.auth import _build_credentials_tuple
return _build_credentials_tuple(mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
"""Parse read preference options."""
if "read_preference" in options:
return options["read_preference"]
name = options.get("readpreference", "primary")
mode = read_pref_mode_from_name(name)
tags = options.get("readpreferencetags")
max_staleness = options.get("maxstalenessseconds", -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
"""Parse write concern options."""
concern = options.get("w")
wtimeout = options.get("wtimeoutms")
j = options.get("journal")
fsync = options.get("fsync")
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
"""Parse read concern options."""
concern = options.get("readconcernlevel")
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
validate_boolean("tls", use_tls)
certfile = options.get("tlscertificatekeyfile")
passphrase = options.get("tlscertificatekeyfilepassword")
ca_certs = options.get("tlscafile")
crlfile = options.get("tlscrlfile")
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
enabled_tls_opts = []
for opt in (
"tlscertificatekeyfile",
"tlscertificatekeyfilepassword",
"tlscafile",
"tlscrlfile",
):
# Any non-null value of these options implies tls=True.
if opt in options and options[opt]:
enabled_tls_opts.append(opt)
for opt in (
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
):
# A value of False for these options implies tls=True.
if opt in options and not options[opt]:
enabled_tls_opts.append(opt)
if enabled_tls_opts:
if use_tls is None:
# Implicitly enable TLS when one of the tls* options is set.
use_tls = True
elif not use_tls:
# Error since tls is explicitly disabled but a tls option is set.
raise ConfigurationError(
"TLS has not been enabled but the "
"following tls parameters have been set: "
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
)
if use_tls:
ctx = get_ssl_context(
certfile,
passphrase,
ca_certs,
crlfile,
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
def _parse_pool_options(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(
max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout,
socket_timeout,
wait_queue_timeout,
ssl_context,
tls_allow_invalid_hostnames,
_EventListeners(event_listeners),
appname,
driver,
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced,
credentials=credentials,
)
class ClientOptions:
"""Read only configuration options for an AsyncMongoClient.
Should not be instantiated directly by application developers. Access
a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options`
instead.
"""
def __init__(
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__direct_connection = options.get("directconnection")
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
)
self.__pool_options = _parse_pool_options(username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get("replicaset")
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get("connect")
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
self.__server_selector = options.get("server_selector", any_server_selector)
self.__auto_encryption_opts = options.get("auto_encryption_opts")
self.__load_balanced = options.get("loadbalanced")
self.__timeout = options.get("timeoutms")
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
@property
def _options(self) -> Mapping[str, Any]:
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self) -> Optional[bool]:
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self) -> CodecOptions:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def direct_connection(self) -> Optional[bool]:
"""Whether to connect to the deployment in 'Single' topology."""
return self.__direct_connection
@property
def local_threshold_ms(self) -> int:
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self) -> int:
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self) -> _ServerSelector:
return self.__server_selector
@property
def heartbeat_frequency(self) -> int:
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self) -> PoolOptions:
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self) -> _ServerMode:
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self) -> WriteConcern:
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self) -> ReadConcern:
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def timeout(self) -> Optional[float]:
"""The configured timeoutMS converted to seconds, or None.
.. versionadded:: 4.2
"""
return self.__timeout
@property
def retry_writes(self) -> bool:
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self) -> bool:
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self.__load_balanced
@property
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()
@property
def server_monitoring_mode(self) -> str:
"""The configured serverMonitoringMode option.
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,226 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with `collations`_.
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
from __future__ import annotations
from typing import Any, Mapping, Optional, Union
from pymongo.asynchronous import common
from pymongo.write_concern import validate_boolean
_IS_SYNC = False
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = "non-ignorable"
"""Spaces and punctuation are treated as base characters."""
SHIFTED = "shifted"
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished regardless when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = "punct"
"""Both punctuation and spaces are ignored."""
SPACE = "space"
"""Spaces alone are ignored."""
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = "upper"
"""Sort uppercase characters first."""
LOWER = "lower"
"""Sort lowercase characters first."""
OFF = "off"
"""Default for locale or collation strength."""
class Collation:
"""Collation
:param locale: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
:param caseFirst: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
:param strength: Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
:param numericOrdering: If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
:param alternate: Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
:param maxVariable: When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
:param normalization: If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
:param backwards: If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
:param kwargs: Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded: 3.4
"""
__slots__ = ("__document",)
def __init__(
self,
locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
if strength is not None:
self.__document["strength"] = common.validate_integer("strength", strength)
if numericOrdering is not None:
self.__document["numericOrdering"] = validate_boolean(
"numericOrdering", numericOrdering
)
if alternate is not None:
self.__document["alternate"] = common.validate_string("alternate", alternate)
if maxVariable is not None:
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
if normalization is not None:
self.__document["normalization"] = validate_boolean("normalization", normalization)
if backwards is not None:
self.__document["backwards"] = validate_boolean("backwards", backwards)
self.__document.update(kwargs)
@property
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self) -> str:
document = self.document
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,415 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CommandCursor class to iterate over command results."""
from __future__ import annotations
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generic,
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
_RawBatchGetMore,
)
from pymongo.asynchronous.response import PinnedResponse
from pymongo.asynchronous.typings import _Address, _DocumentOut, _DocumentType
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import Connection
_IS_SYNC = False
class AsyncCommandCursor(Generic[_DocumentType]):
"""An asynchronous cursor / iterator over command cursors."""
_getmore_class = _GetMore
def __init__(
self,
collection: AsyncCollection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
self._sock_mgr: Any = None
self._collection: AsyncCollection[_DocumentType] = collection
self._id = cursor_info["id"]
self._data = deque(cursor_info["firstBatch"])
self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
"postBatchResumeToken"
)
self._address = address
self._batch_size = batch_size
self._max_await_time_ms = max_await_time_ms
self._session = session
self._explicit_session = explicit_session
self._killed = self._id == 0
self._comment = comment
if _IS_SYNC and self._killed:
self._end_session(True) # type: ignore[unused-coroutine]
if "ns" in cursor_info: # noqa: SIM401
self._ns = cursor_info["ns"]
else:
self._ns = collection.full_name
self.batch_size(batch_size)
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self) -> None:
if _IS_SYNC:
self._die(False) # type: ignore[unused-coroutine]
def batch_size(self, batch_size: int) -> AsyncCommandCursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
.. note:: batch_size can not override MongoDB's internal limits on the
amount of data it will return to the client in a single batch (i.e
if you set batch size to 1,000,000,000, MongoDB will currently only
return 4-16MB of results per batch).
Raises :exc:`TypeError` if `batch_size` is not an integer.
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self._batch_size = batch_size == 1 and 2 or batch_size
return self
def _has_next(self) -> bool:
"""Returns `True` if the cursor has documents remaining from the
previous batch.
"""
return len(self._data) > 0
@property
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore.
"""
return self._postbatchresumetoken
async def _maybe_pin_connection(self, conn: Connection) -> None:
client = self._collection.database.client
if not client._should_pin_cursor(self._session):
return
if not self._sock_mgr:
conn.pin_cursor()
conn_mgr = _ConnectionManager(conn, False)
# Ensure the connection gets returned when the entire result is
# returned in the first batch.
if self._id == 0:
await conn_mgr.close()
else:
self._sock_mgr = conn_mgr
def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration`. Best to use a for loop::
async for doc in collection.aggregate(pipeline):
print(doc)
.. note:: :attr:`alive` can be True while iterating a cursor from
a failed server. In this case :attr:`alive` will return False after
:meth:`next` fails to retrieve the next batch of results from the
server.
"""
return bool(len(self._data) or (not self._killed))
@property
def cursor_id(self) -> int:
"""Returns the id of the cursor."""
return self._id
@property
def address(self) -> Optional[_Address]:
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
"""
return self._address
@property
def session(self) -> Optional[ClientSession]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
.. versionadded:: 3.6
"""
if self._explicit_session:
return self._session
return None
async def _die(self, synchronous: bool = False) -> None:
"""Closes this cursor."""
already_killed = self._killed
self._killed = True
if self._id and not already_killed:
cursor_id = self._id
assert self._address is not None
address = _CursorAddress(self._address, self._ns)
else:
# Skip killCursors.
cursor_id = 0
address = None
await self._collection.database.client._cleanup_cursor(
synchronous,
cursor_id,
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
async def _end_session(self, synchronous: bool) -> None:
if self._session and not self._explicit_session:
await self._session._end_session(lock=synchronous)
self._session = None
async def close(self) -> None:
"""Explicitly close / kill this cursor."""
await self._die(True)
async def _send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
client = self._collection.database.client
try:
response = await client._run_operation(
operation, self._unpack_response, address=self._address
)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS:
# Don't send killCursors because the cursor is already closed.
self._killed = True
if exc.timeout:
await self._die(False)
else:
# Return the session and pinned connection, if necessary.
await self.close()
raise
except ConnectionFailure:
# Don't send killCursors because the cursor is already closed.
self._killed = True
# Return the session and pinned connection, if necessary.
await self.close()
raise
except Exception:
await self.close()
raise
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
self._postbatchresumetoken = cursor.get("postBatchResumeToken")
self._id = cursor["id"]
else:
documents = response.docs
assert isinstance(response.data, _OpReply)
self._id = response.data.cursor_id
if self._id == 0:
await self.close()
self._data = deque(documents)
async def _refresh(self) -> int:
"""Refreshes the cursor with more data from the server.
Returns the length of self._data after refresh. Will exit early if
self._data is already non-empty. Raises OperationFailure when the
cursor cannot be refreshed due to an error on the query.
"""
if len(self._data) or self._killed:
return len(self._data)
if self._id: # Get More
dbname, collname = self._ns.split(".", 1)
read_pref = self._collection._read_preference_for(self.session)
await self._send_message(
self._getmore_class(
dbname,
collname,
self._batch_size,
self._id,
self._collection.codec_options,
read_pref,
self._session,
self._collection.database.client,
self._max_await_time_ms,
self._sock_mgr,
False,
self._comment,
)
)
else: # Cursor id is zero nothing else to return
await self._die(True)
return len(self._data)
def __aiter__(self) -> AsyncIterator[_DocumentType]:
return self
async def next(self) -> _DocumentType:
"""Advance the cursor."""
# Block until a document is returnable.
while self.alive:
doc = await self._try_next(True)
if doc is not None:
return doc
raise StopAsyncIteration
async def __anext__(self) -> _DocumentType:
return await self.next()
async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
"""Advance the cursor blocking for at most one getMore command."""
if not len(self._data) and not self._killed and get_more_allowed:
await self._refresh()
if len(self._data):
return self._data.popleft()
else:
return None
async def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next document without waiting
indefinitely for data.
If no document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there is no additional data) then ``None`` is returned.
:return: The next document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 4.5
"""
return await self._try_next(get_more_allowed=True)
async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
async def to_list(self) -> list[_DocumentType]:
return [x async for x in self] # noqa: C416,RUF100
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
_getmore_class = _RawBatchGetMore
def __init__(
self,
collection: AsyncCollection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
see :meth:`~pymongo.collection.AsyncCollection.aggregate_raw_batches`
instead.
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
assert not cursor_info.get("firstBatch")
super().__init__(
collection,
cursor_info,
address,
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)
def _unpack_response( # type: ignore[override]
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[Mapping[str, Any]]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
# Re-assemble the array of documents into a document stream
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response # type: ignore[return-value]
def __getitem__(self, index: int) -> NoReturn:
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,178 @@
# Copyright 2018 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Any, Iterable, Optional, Union
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
_IS_SYNC = False
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
def _have_snappy() -> bool:
try:
import snappy # type:ignore[import] # noqa: F401
return True
except ImportError:
return False
def _have_zlib() -> bool:
try:
import zlib # noqa: F401
return True
except ImportError:
return False
def _have_zstd() -> bool:
try:
import zstandard # noqa: F401
return True
except ImportError:
return False
def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]:
try:
# `value` is string.
compressors = value.split(",") # type: ignore[union-attr]
except AttributeError:
# `value` is an iterable.
compressors = list(value)
for compressor in compressors[:]:
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2)
elif compressor == "snappy" and not _have_snappy():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with snappy is not available. "
"You must install the python-snappy module for snappy support.",
stacklevel=2,
)
elif compressor == "zlib" and not _have_zlib():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zlib is not available. "
"The zlib module is not available.",
stacklevel=2,
)
elif compressor == "zstd" and not _have_zstd():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"You must install the zstandard module for zstandard support.",
stacklevel=2,
)
return compressors
def validate_zlib_compression_level(option: str, value: Any) -> int:
try:
level = int(value)
except Exception:
raise TypeError(f"{option} must be an integer, not {value!r}.") from None
if level < -1 or level > 9:
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
return level
class CompressionSettings:
def __init__(self, compressors: list[str], zlib_compression_level: int):
self.compressors = compressors
self.zlib_compression_level = zlib_compression_level
def get_compression_context(
self, compressors: Optional[list[str]]
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
if compressors:
chosen = compressors[0]
if chosen == "snappy":
return SnappyContext()
elif chosen == "zlib":
return ZlibContext(self.zlib_compression_level)
elif chosen == "zstd":
return ZstdContext()
return None
return None
class SnappyContext:
compressor_id = 1
@staticmethod
def compress(data: bytes) -> bytes:
import snappy
return snappy.compress(data)
class ZlibContext:
compressor_id = 2
def __init__(self, level: int):
self.level = level
def compress(self, data: bytes) -> bytes:
import zlib
return zlib.compress(data, self.level)
class ZstdContext:
compressor_id = 3
@staticmethod
def compress(data: bytes) -> bytes:
# ZstdCompressor is not thread safe.
# TODO: Use a pool?
import zstandard
return zstandard.ZstdCompressor().compress(data)
def decompress(data: bytes, compressor_id: int) -> bytes:
if compressor_id == SnappyContext.compressor_id:
# python-snappy doesn't support the buffer interface.
# https://github.com/andrix/python-snappy/issues/65
# This only matters when data is a memoryview since
# id(bytes(data)) == id(data) when data is a bytes.
import snappy
return snappy.uncompress(bytes(data))
elif compressor_id == ZlibContext.compressor_id:
import zlib
return zlib.decompress(data)
elif compressor_id == ZstdContext.compressor_id:
# ZstdDecompressor is not thread safe.
# TODO: Use a pool?
import zstandard
return zstandard.ZstdDecompressor().decompress(data)
else:
raise ValueError("Unknown compressorId %d" % (compressor_id,))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,270 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for automatic client-side field level encryption."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
try:
import pymongocrypt # type:ignore[import] # noqa: F401
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.asynchronous.common import validate_is_mapping
from pymongo.asynchronous.uri_parser import _parse_kms_tls_options
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.typings import _DocumentTypeArg
_IS_SYNC = False
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
def __init__(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional[AsyncMongoClient[_DocumentTypeArg]] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: bool = False,
mongocryptd_uri: str = "mongodb://localhost:27020",
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = "mongocryptd",
mongocryptd_spawn_args: Optional[list[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None,
crypt_shared_lib_path: Optional[str] = None,
crypt_shared_lib_required: bool = False,
bypass_query_analysis: bool = False,
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
) -> None:
"""Options to configure automatic client-side field level encryption.
Automatic client-side field level encryption requires MongoDB >=4.2
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
supported for operations on a database or view and will result in
error.
Although automatic encryption requires MongoDB >=4.2 enterprise or a
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
users. To configure automatic *decryption* without automatic
*encryption* set ``bypass_auto_encryption=True``. Explicit
encryption and explicit decryption is also supported for all users
with the :class:`~pymongo.encryption.ClientEncryption` class.
See :ref:`automatic-client-side-encryption` for an example.
:param kms_providers: Map of KMS provider options. The `kms_providers`
map values differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages. An optional "sessionToken" may be
included to support temporary AWS credentials.
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
strings. Additionally, "identityPlatformEndpoint" may also be
specified as a string (defaults to 'login.microsoftonline.com').
These are the Azure Active Directory credentials used to
generate Azure Key Vault messages.
- `gcp`: Map with "email" as a string and "privateKey"
as `bytes` or a base64 encoded string.
Additionally, "endpoint" may also be specified as a string
(defaults to 'oauth2.googleapis.com'). These are the
credentials used to generate Google Cloud KMS messages.
- `kmip`: Map with "endpoint" as a host with required port.
For example: ``{"endpoint": "example.com:443"}``.
- `local`: Map with "key" as `bytes` (96 bytes in length) or
a base64 encoded string which decodes
to 96 bytes. "key" is the master key used to encrypt/decrypt
data keys. This key should be generated and stored as securely
as possible.
KMS providers may be specified with an optional name suffix
separated by a colon, for example "kmip:name" or "aws:name".
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
Named KMS providers enables more than one of each KMS provider type to be configured.
For example, to configure multiple local KMS providers::
kms_providers = {
"local": {"key": local_kek1}, # Unnamed KMS provider.
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
}
:param key_vault_namespace: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
:param key_vault_client: By default, the key vault collection
is assumed to reside in the same MongoDB cluster as the encrypted
AsyncMongoClient. Use this option to route data key queries to a
separate MongoDB cluster.
:param schema_map: Map of collection namespace ("db.coll") to
JSON Schema. By default, a collection's JSONSchema is periodically
polled with the listCollections command. But a JSONSchema may be
specified locally with the schemaMap option.
**Supplying a `schema_map` provides more security than relying on
JSON Schemas obtained from the server. It protects against a
malicious server advertising a false JSON Schema, which could trick
the client into sending unencrypted data that should be
encrypted.**
Schemas supplied in the schemaMap only apply to configuring
automatic encryption for client side encryption. Other validation
rules in the JSON schema will not be enforced by the driver and
will result in an error.
:param bypass_auto_encryption: If ``True``, automatic
encryption will be disabled but automatic decryption will still be
enabled. Defaults to ``False``.
:param mongocryptd_uri: The MongoDB URI used to connect
to the *local* mongocryptd process. Defaults to
``'mongodb://localhost:27020'``.
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
AsyncMongoClient will not attempt to spawn the mongocryptd process.
Defaults to ``False``.
:param mongocryptd_spawn_path: Used for spawning the
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
mongocryptd from the system path.
:param mongocryptd_spawn_args: A list of string arguments to
use when spawning the mongocryptd process. Defaults to
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
the ``idleShutdownTimeoutSecs`` option then
``'--idleShutdownTimeoutSecs=60'`` will be added.
:param kms_tls_options: A map of KMS provider names to TLS
options to use when creating secure connections to KMS providers.
Accepts the same TLS options as
:class:`pymongo.mongo_client.AsyncMongoClient`. For example, to
override the system default CA file::
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
Or to supply a client certificate::
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
unable to load the crypt_shared library.
:param bypass_query_analysis: If ``True``, disable automatic analysis
of outgoing commands. Set `bypass_query_analysis` to use explicit
encryption on indexed fields without the MongoDB Enterprise Advanced
licensed crypt_shared library.
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
that described the encrypted fields for Queryable Encryption. For example::
{
"db.encryptedCollection": {
"escCollection": "enxcol_.encryptedCollection.esc",
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
"fields": [
{
"path": "firstName",
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
"bsonType": "string",
"queries": {"queryType": "equality"}
},
{
"path": "ssn",
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
"bsonType": "string"
}
]
}
}
.. versionchanged:: 4.2
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
and `bypass_query_analysis` parameters.
.. versionchanged:: 4.0
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
self._bypass_query_analysis = bypass_query_analysis
self._crypt_shared_lib_path = crypt_shared_lib_path
self._crypt_shared_lib_required = crypt_shared_lib_required
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._schema_map = schema_map
self._bypass_auto_encryption = bypass_auto_encryption
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
if mongocryptd_spawn_args is None:
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError("mongocryptd_spawn_args must be a list")
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
self._bypass_query_analysis = bypass_query_analysis
class RangeOpts:
"""Options to configure encrypted queries using the rangePreview algorithm."""
def __init__(
self,
sparsity: int,
min: Optional[Any] = None,
max: Optional[Any] = None,
precision: Optional[int] = None,
) -> None:
"""Options to configure encrypted queries using the rangePreview algorithm.
.. note:: This feature is experimental only, and not intended for public use.
:param sparsity: An integer.
:param min: A BSON scalar value corresponding to the type being queried.
:param max: A BSON scalar value corresponding to the type being queried.
:param precision: An integer, may only be set for double or decimal128 types.
.. versionadded:: 4.4
"""
self.min = min
self.max = max
self.sparsity = sparsity
self.precision = precision
@property
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),
("precision", self.precision),
("min", self.min),
("max", self.max),
]:
if v is not None:
doc[k] = v
return doc

View File

@ -0,0 +1,225 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example event logger classes.
.. versionadded:: 3.11
These loggers can be registered using :func:`register` or
:class:`~pymongo.mongo_client.MongoClient`.
``monitoring.register(CommandLogger())``
or
``MongoClient(event_listeners=[CommandLogger()])``
"""
from __future__ import annotations
import logging
from pymongo.asynchronous import monitoring
_IS_SYNC = False
class CommandLogger(monitoring.CommandListener):
"""A simple listener that logs command events.
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
:class:`~pymongo.monitoring.CommandSucceededEvent` and
:class:`~pymongo.monitoring.CommandFailedEvent` events and
logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.CommandStartedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} started on server "
f"{event.connection_id}"
)
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"succeeded in {event.duration_micros} "
"microseconds"
)
def failed(self, event: monitoring.CommandFailedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"failed in {event.duration_micros} "
"microseconds"
)
class ServerLogger(monitoring.ServerListener):
"""A simple listener that logs server discovery events.
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.ServerClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
previous_server_type = event.previous_description.server_type
new_server_type = event.new_description.server_type
if new_server_type != previous_server_type:
# server_type_name was added in PyMongo 3.4
logging.info(
f"Server {event.server_address} changed type from "
f"{event.previous_description.server_type_name} to "
f"{event.new_description.server_type_name}"
)
def closed(self, event: monitoring.ServerClosedEvent) -> None:
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
"""A simple listener that logs server heartbeat events.
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
logging.info(f"Heartbeat sent to server {event.connection_id}")
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
# The reply.document attribute was added in PyMongo 3.4.
logging.info(
f"Heartbeat to server {event.connection_id} "
"succeeded with reply "
f"{event.reply.document}"
)
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
logging.warning(
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
)
class TopologyLogger(monitoring.TopologyListener):
"""A simple listener that logs server topology events.
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.TopologyClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} opened")
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
logging.info(f"Topology description updated for topology id {event.topology_id}")
previous_topology_type = event.previous_description.topology_type
new_topology_type = event.new_description.topology_type
if new_topology_type != previous_topology_type:
# topology_type_name was added in PyMongo 3.4
logging.info(
f"Topology {event.topology_id} changed type from "
f"{event.previous_description.topology_type_name} to "
f"{event.new_description.topology_type_name}"
)
# The has_writable_server and has_readable_server methods
# were added in PyMongo 3.4.
if not event.new_description.has_writable_server():
logging.warning("No writable servers available.")
if not event.new_description.has_readable_server():
logging.warning("No readable servers available.")
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} closed")
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
"""A simple listener that logs server connection pool events.
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
:class:`~pymongo.monitoring.PoolClearedEvent`,
:class:`~pymongo.monitoring.PoolClosedEvent`,
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
logging.info(f"[pool {event.address}] pool created")
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
logging.info(f"[pool {event.address}] pool ready")
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
logging.info(f"[pool {event.address}] pool cleared")
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
logging.info(f"[pool {event.address}] pool closed")
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
)
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] "
f'connection closed, reason: "{event.reason}"'
)
def connection_check_out_started(
self, event: monitoring.ConnectionCheckOutStartedEvent
) -> None:
logging.info(f"[pool {event.address}] connection check out started")
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
)
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
)

View File

@ -21,17 +21,12 @@ import itertools
from typing import Any, Generic, Mapping, Optional
from bson.objectid import ObjectId
from pymongo import common
from pymongo.asynchronous import common
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.asynchronous.typings import ClusterTime, _DocumentType
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import ClusterTime, _DocumentType
class HelloCompat:
CMD = "hello"
LEGACY_CMD = "ismaster"
PRIMARY = "isWritablePrimary"
LEGACY_PRIMARY = "ismaster"
LEGACY_ERROR = "not master"
_IS_SYNC = False
def _get_server_type(doc: Mapping[str, Any]) -> int:

View File

@ -0,0 +1,26 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The HelloCompat class, placed here to break circular import issues."""
from __future__ import annotations
_IS_SYNC = False
class HelloCompat:
CMD = "hello"
LEGACY_CMD = "ismaster"
PRIMARY = "isWritablePrimary"
LEGACY_PRIMARY = "ismaster"
LEGACY_ERROR = "not master"

View File

@ -0,0 +1,321 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bits and pieces used by the driver that don't really fit elsewhere."""
from __future__ import annotations
import builtins
import sys
import traceback
from collections import abc
from typing import (
TYPE_CHECKING,
Any,
Callable,
Container,
Iterable,
Mapping,
NoReturn,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
from pymongo import ASCENDING
from pymongo.asynchronous.hello_compat import HelloCompat
from pymongo.errors import (
CursorNotFound,
DuplicateKeyError,
ExecutionTimeout,
NotPrimaryError,
OperationFailure,
WriteConcernError,
WriteError,
WTimeoutError,
_wtimeout_error,
)
from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE
if TYPE_CHECKING:
from pymongo.asynchronous.operations import _IndexList
from pymongo.asynchronous.typings import _DocumentOut
from pymongo.cursor_shared import _Hint
_IS_SYNC = False
def _gen_index_name(keys: _IndexList) -> str:
"""Generate an index name from the set of fields it is over."""
return "_".join(["{}_{}".format(*item) for item in keys])
def _index_list(
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
"""Helper to generate a list of (key, direction) pairs.
Takes such a list, or a single key, or a single key and direction.
"""
if direction is not None:
if not isinstance(key_or_list, str):
raise TypeError("Expected a string and a direction")
return [(key_or_list, direction)]
else:
if isinstance(key_or_list, str):
return [(key_or_list, ASCENDING)]
elif isinstance(key_or_list, abc.ItemsView):
return list(key_or_list) # type: ignore[arg-type]
elif isinstance(key_or_list, abc.Mapping):
return list(key_or_list.items())
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
values: list[tuple[str, int]] = []
for item in key_or_list:
if isinstance(item, str):
item = (item, ASCENDING) # noqa: PLW2901
values.append(item)
return values
def _index_document(index_list: _IndexList) -> dict[str, Any]:
"""Helper to generate an index specifying document.
Takes a list of (key, direction) pairs.
"""
if not isinstance(index_list, (list, tuple, abc.Mapping)):
raise TypeError(
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
)
if not len(index_list):
raise ValueError("key_or_list must not be empty")
index: dict[str, Any] = {}
if isinstance(index_list, abc.Mapping):
for key in index_list:
value = index_list[key]
_validate_index_key_pair(key, value)
index[key] = value
else:
for item in index_list:
if isinstance(item, str):
item = (item, ASCENDING) # noqa: PLW2901
key, value = item
_validate_index_key_pair(key, value)
index[key] = value
return index
def _validate_index_key_pair(key: Any, value: Any) -> None:
if not isinstance(key, str):
raise TypeError("first item in each key pair must be an instance of str")
if not isinstance(value, (str, int, abc.Mapping)):
raise TypeError(
"second item in each key pair must be 1, -1, "
"'2d', or another valid MongoDB index specifier."
)
def _check_command_response(
response: _DocumentOut,
max_wire_version: Optional[int],
allowable_errors: Optional[Container[Union[int, str]]] = None,
parse_write_concern_error: bool = False,
) -> None:
"""Check the response to a command for errors."""
if "ok" not in response:
# Server didn't recognize our message as a command.
raise OperationFailure(
response.get("$err"), # type: ignore[arg-type]
response.get("code"),
response,
max_wire_version,
)
if parse_write_concern_error and "writeConcernError" in response:
_error = response["writeConcernError"]
_labels = response.get("errorLabels")
if _labels:
_error.update({"errorLabels": _labels})
_raise_write_concern_error(_error)
if response["ok"]:
return
details = response
# Mongos returns the error details in a 'raw' object
# for some errors.
if "raw" in response:
for shard in response["raw"].values():
# Grab the first non-empty raw error from a shard.
if shard.get("errmsg") and not shard.get("ok"):
details = shard
break
errmsg = details["errmsg"]
code = details.get("code")
# For allowable errors, only check for error messages when the code is not
# included.
if allowable_errors:
if code is not None:
if code in allowable_errors:
return
elif errmsg in allowable_errors:
return
# Server is "not primary" or "recovering"
if code is not None:
if code in _NOT_PRIMARY_CODES:
raise NotPrimaryError(errmsg, response)
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
raise NotPrimaryError(errmsg, response)
# Other errors
# findAndModify with upsert can raise duplicate key error
if code in (11000, 11001, 12582):
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
elif code == 50:
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
elif code == 43:
raise CursorNotFound(errmsg, code, response, max_wire_version)
raise OperationFailure(errmsg, code, response, max_wire_version)
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
# If the last batch had multiple errors only report
# the last error to emulate continue_on_error.
error = write_errors[-1]
if error.get("code") == 11000:
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
raise WriteError(error.get("errmsg"), error.get("code"), error)
def _raise_write_concern_error(error: Any) -> NoReturn:
if _wtimeout_error(error):
# Make sure we raise WTimeoutError
raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
raise WriteConcernError(error.get("errmsg"), error.get("code"), error)
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
"""Return the writeConcernError or None."""
wce = result.get("writeConcernError")
if wce:
# The server reports errorLabels at the top level but it's more
# convenient to attach it to the writeConcernError doc itself.
error_labels = result.get("errorLabels")
if error_labels:
# Copy to avoid changing the original document.
wce = wce.copy()
wce["errorLabels"] = error_labels
return wce
def _check_write_command_response(result: Mapping[str, Any]) -> None:
"""Backward compatibility helper for write command error handling."""
# Prefer write errors over write concern errors
write_errors = result.get("writeErrors")
if write_errors:
_raise_last_write_error(write_errors)
wce = _get_wce_doc(result)
if wce:
_raise_write_concern_error(wce)
def _fields_list_to_dict(
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
"""Takes a sequence of field names and returns a matching dictionary.
["a", "b"] becomes {"a": 1, "b": 1}
and
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
"""
if isinstance(fields, abc.Mapping):
return fields
if isinstance(fields, (abc.Sequence, abc.Set)):
if not all(isinstance(field, str) for field in fields):
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
return dict.fromkeys(fields, 1)
raise TypeError(f"{option_name} must be a mapping or list of key names")
def _handle_exception() -> None:
"""Print exceptions raised by subscribers to stderr."""
# Heavily influenced by logging.Handler.handleError.
# See note here:
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
if sys.stderr:
einfo = sys.exc_info()
try:
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
except OSError:
pass
finally:
del einfo
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.message import _BulkWriteContext
from pymongo.asynchronous.pool import Connection
try:
return await func(*args, **kwargs)
except OperationFailure as exc:
if no_reauth:
raise
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
# Look for an argument that either is a Connection
# or has a connection attribute, so we can trigger
# a reauth.
conn = None
for arg in args:
if isinstance(arg, Connection):
conn = arg
break
if isinstance(arg, _BulkWriteContext):
conn = arg.conn
break
if conn:
await conn.authenticate(reauthenticate=True)
else:
raise
return func(*args, **kwargs)
raise
return cast(F, inner)
async def anext(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
if sys.version_info >= (3, 10):
return await builtins.anext(cls)
else:
return await cls.__anext__()

View File

@ -0,0 +1,171 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
import logging
import os
import warnings
from typing import Any
from bson import UuidRepresentation, json_util
from bson.json_util import JSONOptions, _truncate_documents
from pymongo.asynchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
_IS_SYNC = False
class _CommandStatusMessage(str, enum.Enum):
STARTED = "Command started"
SUCCEEDED = "Command succeeded"
FAILED = "Command failed"
class _ServerSelectionStatusMessage(str, enum.Enum):
STARTED = "Server selection started"
SUCCEEDED = "Server selection succeeded"
FAILED = "Server selection failed"
WAITING = "Waiting for suitable server to become available"
class _ConnectionStatusMessage(str, enum.Enum):
POOL_CREATED = "Connection pool created"
POOL_READY = "Connection pool ready"
POOL_CLOSED = "Connection pool closed"
POOL_CLEARED = "Connection pool cleared"
CONN_CREATED = "Connection created"
CONN_READY = "Connection ready"
CONN_CLOSED = "Connection closed"
CHECKOUT_STARTED = "Connection checkout started"
CHECKOUT_SUCCEEDED = "Connection checked out"
CHECKOUT_FAILED = "Connection checkout failed"
CHECKEDIN = "Connection checked in"
_DEFAULT_DOCUMENT_LENGTH = 1000
_SENSITIVE_COMMANDS = [
"authenticate",
"saslStart",
"saslContinue",
"getnonce",
"createUser",
"updateUser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
]
_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"]
_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"]
_DOCUMENT_NAMES = ["command", "reply", "failure"]
_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD)
_COMMAND_LOGGER = logging.getLogger("pymongo.command")
_CONNECTION_LOGGER = logging.getLogger("pymongo.connection")
_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection")
_CLIENT_LOGGER = logging.getLogger("pymongo.client")
_VERBOSE_CONNECTION_ERROR_REASONS = {
ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed",
ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed",
ConnectionClosedReason.STALE: "Connection pool was stale",
ConnectionClosedReason.ERROR: "An error occurred while using the connection",
ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection",
ConnectionClosedReason.IDLE: "Connection was idle too long",
ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout",
}
def _debug_log(logger: logging.Logger, **fields: Any) -> None:
logger.debug(LogMessage(**fields))
def _verbose_connection_error_reason(reason: str) -> str:
return _VERBOSE_CONNECTION_ERROR_REASONS.get(reason, reason)
def _info_log(logger: logging.Logger, **fields: Any) -> None:
logger.info(LogMessage(**fields))
def _log_or_warn(logger: logging.Logger, message: str) -> None:
if logger.isEnabledFor(logging.INFO):
logger.info(message)
else:
# stacklevel=4 ensures that the warning is for the user's code.
warnings.warn(message, UserWarning, stacklevel=4)
class LogMessage:
__slots__ = ("_kwargs", "_redacted")
def __init__(self, **kwargs: Any):
self._kwargs = kwargs
self._redacted = False
def __str__(self) -> str:
self._redact()
return "%s" % (
json_util.dumps(
self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__()
)
)
def _is_sensitive(self, doc_name: str) -> bool:
is_speculative_authenticate = (
self._kwargs.pop("speculative_authenticate", False)
or "speculativeAuthenticate" in self._kwargs[doc_name]
)
is_sensitive_command = (
"commandName" in self._kwargs and self._kwargs["commandName"] in _SENSITIVE_COMMANDS
)
is_sensitive_hello = (
self._kwargs["commandName"] in _HELLO_COMMANDS and is_speculative_authenticate
)
return is_sensitive_command or is_sensitive_hello
def _redact(self) -> None:
if self._redacted:
return
self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None}
if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"):
self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000
if "serviceId" in self._kwargs:
self._kwargs["serviceId"] = str(self._kwargs["serviceId"])
document_length = int(os.getenv("MONGOB_LOG_MAX_DOCUMENT_LENGTH", _DEFAULT_DOCUMENT_LENGTH))
if document_length < 0:
document_length = _DEFAULT_DOCUMENT_LENGTH
is_server_side_error = self._kwargs.pop("isServerSideError", False)
for doc_name in _DOCUMENT_NAMES:
doc = self._kwargs.get(doc_name)
if doc:
if doc_name == "failure" and is_server_side_error:
doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS}
if doc_name != "failure" and self._is_sensitive(doc_name):
doc = json_util.dumps({})
else:
truncated_doc = _truncate_documents(doc, document_length)[0]
doc = json_util.dumps(
truncated_doc,
json_options=_JSON_OPTIONS,
default=lambda o: o.__repr__(),
)
if len(doc) > document_length:
doc = (
doc.encode()[:document_length].decode("unicode-escape", "ignore")
) + "..."
self._kwargs[doc_name] = doc
self._redacted = True

View File

@ -0,0 +1,125 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Criteria to select ServerDescriptions based on maxStalenessSeconds.
The Max Staleness Spec says: When there is a known primary P,
a secondary S's staleness is estimated with this formula:
(S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate)
+ heartbeatFrequencyMS
When there is no known primary, a secondary S's staleness is estimated with:
SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS
where "SMax" is the secondary with the greatest lastWriteDate.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from pymongo.errors import ConfigurationError
from pymongo.server_type import SERVER_TYPE
if TYPE_CHECKING:
from pymongo.asynchronous.server_selectors import Selection
_IS_SYNC = False
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
# 10 seconds to refresh secondaries' lastWriteDate values.
IDLE_WRITE_PERIOD = 10
SMALLEST_MAX_STALENESS = 90
def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None:
# We checked for max staleness -1 before this, it must be positive here.
if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD:
raise ConfigurationError(
"maxStalenessSeconds must be at least heartbeatFrequencyMS +"
" %d seconds. maxStalenessSeconds is set to %d,"
" heartbeatFrequencyMS is set to %d."
% (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000)
)
if max_staleness < SMALLEST_MAX_STALENESS:
raise ConfigurationError(
"maxStalenessSeconds must be at least %d. "
"maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness)
)
def _with_primary(max_staleness: int, selection: Selection) -> Selection:
"""Apply max_staleness, in seconds, to a Selection with a known primary."""
primary = selection.primary
assert primary
sds = []
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
assert s.last_write_date and primary.last_write_date # noqa: PT018
staleness = (
(s.last_update_time - s.last_write_date)
- (primary.last_update_time - primary.last_write_date)
+ selection.heartbeat_frequency
)
if staleness <= max_staleness:
sds.append(s)
else:
sds.append(s)
return selection.with_server_descriptions(sds)
def _no_primary(max_staleness: int, selection: Selection) -> Selection:
"""Apply max_staleness, in seconds, to a Selection with no known primary."""
# Secondary that's replicated the most recent writes.
smax = selection.secondary_with_max_last_write_date()
if not smax:
# No secondaries and no primary, short-circuit out of here.
return selection.with_server_descriptions([])
sds = []
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
assert smax.last_write_date and s.last_write_date # noqa: PT018
staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency
if staleness <= max_staleness:
sds.append(s)
else:
sds.append(s)
return selection.with_server_descriptions(sds)
def select(max_staleness: int, selection: Selection) -> Selection:
"""Apply max_staleness, in seconds, to a Selection."""
if max_staleness == -1:
return selection
# Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or
# ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness <
# heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90.
_validate_max_staleness(max_staleness, selection.heartbeat_frequency)
if selection.primary:
return _with_primary(max_staleness, selection)
else:
return _no_primary(max_staleness, selection)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,487 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Class to monitor a MongoDB server on a background thread."""
from __future__ import annotations
import atexit
import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo._csot import MovingMinimum
from pymongo.asynchronous import common, periodic_executor
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.periodic_executor import _shutdown_executors
from pymongo.asynchronous.pool import _is_faas
from pymongo.asynchronous.read_preferences import MovingAverage
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.srv_resolver import _SrvResolver
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.lock import _create_lock
if TYPE_CHECKING:
from pymongo.asynchronous.pool import Connection, Pool, _CancellationContext
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
_IS_SYNC = False
def _sanitize(error: Exception) -> None:
"""PYTHON-2433 Clear error traceback info."""
error.__traceback__ = None
error.__context__ = None
error.__cause__ = None
class MonitorBase:
def __init__(self, topology: Topology, name: str, interval: int, min_interval: float):
"""Base class to do periodic work on a background thread.
The background thread is signaled to stop when the Topology or
this instance is freed.
"""
# We strongly reference the executor and it weakly references us via
# this closure. When the monitor is freed, stop the executor soon.
async def target() -> bool:
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
await monitor._run() # type:ignore[attr-defined]
return True
executor = periodic_executor.PeriodicExecutor(
interval=interval, min_interval=min_interval, target=target, name=name
)
self._executor = executor
def _on_topology_gc(dummy: Optional[Topology] = None) -> None:
# This prevents GC from waiting 10 seconds for hello to complete
# See test_cleanup_executors_on_client_del.
monitor = self_ref()
if monitor:
monitor.gc_safe_close()
# Avoid cycles. When self or topology is freed, stop executor soon.
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, _on_topology_gc)
_register(self)
def open(self) -> None:
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
self._executor.open()
def gc_safe_close(self) -> None:
"""GC safe close."""
self._executor.close()
async def close(self) -> None:
"""Close and stop monitoring.
open() restarts the monitor after closing.
"""
self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
self._executor.wake()
class Monitor(MonitorBase):
def __init__(
self,
server_description: ServerDescription,
topology: Topology,
pool: Pool,
topology_settings: TopologySettings,
):
"""Class to monitor a MongoDB server on a background thread.
Pass an initial ServerDescription, a Topology, a Pool, and
TopologySettings.
The Topology is weakly referenced. The Pool must be exclusive to this
Monitor.
"""
super().__init__(
topology,
"pymongo_server_monitor_thread",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
self._server_description = server_description
self._pool = pool
self._settings = topology_settings
self._listeners = self._settings._pool_options._event_listeners
self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
self._cancel_context: Optional[_CancellationContext] = None
self._rtt_monitor = _RttMonitor(
topology,
topology_settings,
topology._create_pool_for_monitor(server_description.address),
)
if topology_settings.server_monitoring_mode == "stream":
self._stream = True
elif topology_settings.server_monitoring_mode == "poll":
self._stream = False
else:
self._stream = not _is_faas()
def cancel_check(self) -> None:
"""Cancel any concurrent hello check.
Note: this is called from a weakref.proxy callback and MUST NOT take
any locks.
"""
context = self._cancel_context
if context:
# Note: we cannot close the socket because doing so may cause
# concurrent reads/writes to hang until a timeout occurs
# (depending on the platform).
context.cancel()
async def _start_rtt_monitor(self) -> None:
"""Start an _RttMonitor that periodically runs ping."""
# If this monitor is closed directly before (or during) this open()
# call, the _RttMonitor will not be closed. Checking if this monitor
# was closed directly after resolves the race.
self._rtt_monitor.open()
if self._executor._stopped:
await self._rtt_monitor.close()
def gc_safe_close(self) -> None:
self._executor.close()
self._rtt_monitor.gc_safe_close()
self.cancel_check()
async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._reset_connection()
async def _reset_connection(self) -> None:
# Clear our pooled connection.
await self._pool.reset()
async def _run(self) -> None:
try:
prev_sd = self._server_description
try:
self._server_description = await self._check_server()
except _OperationCancelled as exc:
_sanitize(exc)
# Already closed the connection, wait for the next check.
self._server_description = ServerDescription(
self._server_description.address, error=exc
)
if prev_sd.is_server_type_known:
# Immediately retry since we've already waited 500ms to
# discover that we've been cancelled.
self._executor.skip_sleep()
return
# Update the Topology and clear the server pool on error.
await self._topology.on_change(
self._server_description,
reset_pool=self._server_description.error,
interrupt_connections=isinstance(self._server_description.error, NetworkTimeout),
)
if self._stream and (
self._server_description.is_server_type_known
and self._server_description.topology_version
):
await self._start_rtt_monitor()
# Immediately check for the next streaming response.
self._executor.skip_sleep()
if self._server_description.error and prev_sd.is_server_type_known:
# Immediately retry on network errors.
self._executor.skip_sleep()
except ReferenceError:
# Topology was garbage-collected.
await self.close()
async def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.
Returns a ServerDescription.
"""
start = time.monotonic()
try:
try:
return await self._check_once()
except (OperationFailure, NotPrimaryError) as exc:
# Update max cluster time even when hello fails.
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except ReferenceError:
raise
except Exception as error:
_sanitize(error)
sd = self._server_description
address = sd.address
duration = time.monotonic() - start
if self._publish:
awaited = bool(self._stream and sd.is_server_type_known and sd.topology_version)
assert self._listeners is not None
self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
await self._reset_connection()
if isinstance(error, _OperationCancelled):
raise
self._rtt_monitor.reset()
# Server type defaults to Unknown.
return ServerDescription(address, error=error)
async def _check_once(self) -> ServerDescription:
"""A single attempt to call hello.
Returns a ServerDescription, or raises an exception.
"""
address = self._server_description.address
if self._publish:
assert self._listeners is not None
sd = self._server_description
# XXX: "awaited" could be incorrectly set to True in the rare case
# the pool checkout closes and recreates a connection.
awaited = bool(
self._pool.conns
and self._stream
and sd.is_server_type_known
and sd.topology_version
)
self._listeners.publish_server_heartbeat_started(address, awaited)
if self._cancel_context and self._cancel_context.cancelled:
await self._reset_connection()
async with self._pool.checkout() as conn:
self._cancel_context = conn.cancel_context
response, round_trip_time = await self._check_with_socket(conn)
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
avg_rtt, min_rtt = self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish:
assert self._listeners is not None
self._listeners.publish_server_heartbeat_succeeded(
address, round_trip_time, response, response.awaitable
)
return sd
async def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
"""Return (Hello, round_trip_time).
Can raise ConnectionFailure or OperationFailure.
"""
cluster_time = self._topology.max_cluster_time()
start = time.monotonic()
if conn.more_to_come:
# Read the next streaming hello (MongoDB 4.4+).
response = Hello(await conn._next_reply(), awaitable=True)
elif (
self._stream and conn.performed_handshake and self._server_description.topology_version
):
# Initiate streaming hello (MongoDB 4.4+).
response = await conn._hello(
cluster_time,
self._server_description.topology_version,
self._settings.heartbeat_frequency,
)
else:
# New connection handshake or polling hello (MongoDB <4.4).
response = await conn._hello(cluster_time, None, None)
return response, time.monotonic() - start
class SrvMonitor(MonitorBase):
def __init__(self, topology: Topology, topology_settings: TopologySettings):
"""Class to poll SRV records on a background thread.
Pass a Topology and a TopologySettings.
The Topology is weakly referenced.
"""
super().__init__(
topology,
"pymongo_srv_polling_thread",
common.MIN_SRV_RESCAN_INTERVAL,
topology_settings.heartbeat_frequency,
)
self._settings = topology_settings
self._seedlist = self._settings._seeds
assert isinstance(self._settings.fqdn, str)
self._fqdn: str = self._settings.fqdn
self._startup_time = time.monotonic()
async def _run(self) -> None:
# Don't poll right after creation, wait 60 seconds first
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
return
seedlist = self._get_seedlist()
if seedlist:
self._seedlist = seedlist
try:
await self._topology.on_srv_update(self._seedlist)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
"""Poll SRV records for a seedlist.
Returns a list of ServerDescriptions.
"""
try:
resolver = _SrvResolver(
self._fqdn,
self._settings.pool_options.connect_timeout,
self._settings.srv_service_name,
)
seedlist, ttl = resolver.get_hosts_and_min_ttl()
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
# - SRV records must be rescanned every heartbeatFrequencyMS
# - Topology must be left unchanged
self.request_check()
return None
else:
self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
return seedlist
class _RttMonitor(MonitorBase):
def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
"""Maintain round trip times for a server.
The Topology is weakly referenced.
"""
super().__init__(
topology,
"pymongo_server_rtt_thread",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
self._pool = pool
self._moving_average = MovingAverage()
self._moving_min = MovingMinimum()
self._lock = _create_lock()
async def close(self) -> None:
self.gc_safe_close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._pool.reset()
def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
with self._lock:
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)
def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
return self._moving_average.get(), self._moving_min.get()
def reset(self) -> None:
"""Reset the average RTT."""
with self._lock:
self._moving_average.reset()
self._moving_min.reset()
async def _run(self) -> None:
try:
# NOTE: This thread is only run when using the streaming
# heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown?
rtt = await self._ping()
self.add_sample(rtt)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
except Exception:
await self._pool.reset()
async def _ping(self) -> float:
"""Run a "hello" command and return the RTT."""
async with self._pool.checkout() as conn:
if self._executor._stopped:
raise Exception("_RttMonitor closed")
start = time.monotonic()
await conn.hello()
return time.monotonic() - start
# Close monitors to cancel any in progress streaming checks before joining
# executor threads. For an explanation of how this works see the comment
# about _EXECUTORS in periodic_executor.py.
_MONITORS = set()
def _register(monitor: MonitorBase) -> None:
ref = weakref.ref(monitor, _unregister)
_MONITORS.add(ref)
def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None:
_MONITORS.remove(monitor_ref)
def _shutdown_monitors() -> None:
if _MONITORS is None:
return
# Copy the set. Closing monitors removes them.
monitors = list(_MONITORS)
# Close all monitors.
for ref in monitors:
monitor = ref()
if monitor:
monitor.gc_safe_close()
monitor = None
def _shutdown_resources() -> None:
# _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
shutdown = _shutdown_monitors
if shutdown: # type:ignore[truthy-function]
shutdown()
shutdown = _shutdown_executors
if shutdown: # type:ignore[truthy-function]
shutdown()
atexit.register(_shutdown_resources)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,418 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal network layer helper methods."""
from __future__ import annotations
import asyncio
import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
Any,
Mapping,
MutableMapping,
Optional,
Sequence,
Union,
cast,
)
from bson import _decode_all_selective
from pymongo import _csot
from pymongo.asynchronous import helpers as _async_helpers
from pymongo.asynchronous import message as _async_message
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
from pymongo.asynchronous.compression_support import _NO_COMPRESSION, decompress
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.asynchronous.monitoring import _is_speculative_authenticate
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
async_sendall,
)
from pymongo.socket_checker import _errno_from_exception
if TYPE_CHECKING:
from bson import CodecOptions
from pymongo.asynchronous.client_session import ClientSession
from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.monitoring import _EventListeners
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
async def command(
conn: Connection,
dbname: str,
spec: MutableMapping[str, Any],
is_mongos: bool,
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[ClientSession],
client: Optional[AsyncMongoClient],
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
address: Optional[_Address] = None,
listeners: Optional[_EventListeners] = None,
max_bson_size: Optional[int] = None,
read_concern: Optional[ReadConcern] = None,
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
use_op_msg: bool = False,
unacknowledged: bool = False,
user_fields: Optional[Mapping[str, Any]] = None,
exhaust_allowed: bool = False,
write_concern: Optional[WriteConcern] = None,
) -> _DocumentType:
"""Execute a command over the socket, or raise socket.error.
:param conn: a Connection instance
:param dbname: name of the database on which to run the command
:param spec: a command document as an ordered dict type, eg SON.
:param is_mongos: are we connected to a mongos?
:param read_preference: a read preference
:param codec_options: a CodecOptions instance
:param session: optional ClientSession instance.
:param client: optional AsyncMongoClient instance for updating $clusterTime.
:param check: raise OperationFailure if there are errors
:param allowable_errors: errors to ignore if `check` is True
:param address: the (host, port) of `conn`
:param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
:param max_bson_size: The maximum encoded bson size for this server
:param read_concern: The read concern for this command.
:param parse_write_concern_error: Whether to parse the ``writeConcernError``
field in the command response.
:param collation: The collation for this command.
:param compression_ctx: optional compression Context.
:param use_op_msg: True if we should use OP_MSG.
:param unacknowledged: True if this is an unacknowledged command.
:param user_fields: Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
"""
name = next(iter(spec))
ns = dbname + ".$cmd"
speculative_hello = False
# Publish the original command document, perhaps with lsid and $clusterTime.
orig = spec
if is_mongos and not use_op_msg:
assert read_preference is not None
spec = _async_message._maybe_add_read_preference(spec, read_preference)
if read_concern and not (session and session.in_transaction):
if read_concern.level:
spec["readConcern"] = read_concern.document
if session:
session._update_read_concern(spec, conn)
if collation is not None:
spec["collation"] = collation
publish = listeners is not None and listeners.enabled_for_commands
start = datetime.datetime.now()
if publish:
speculative_hello = _is_speculative_authenticate(name, spec)
if compression_ctx and name.lower() in _NO_COMPRESSION:
compression_ctx = None
if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)
# Support CSOT
if client:
conn.apply_timeout(client, spec)
_csot.apply_write_concern(spec, write_concern)
if use_op_msg:
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
request_id, msg, size, max_doc_size = _async_message._op_msg(
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
)
# If this is an unacknowledged write then make sure the encoded doc(s)
# are small enough, otherwise rely on the server to return an error.
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
_async_message._raise_document_too_large(name, size, max_bson_size)
else:
request_id, msg, size = _async_message._query(
0, ns, 0, -1, spec, None, codec_options, compression_ctx
)
if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD:
_async_message._raise_document_too_large(
name, size, max_bson_size + _async_message._COMMAND_OVERHEAD
)
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=spec,
commandName=next(iter(spec)),
databaseName=dbname,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
)
if publish:
assert listeners is not None
assert address is not None
listeners.publish_command_start(
orig,
dbname,
request_id,
address,
conn.server_connection_id,
service_id=conn.service_id,
)
try:
await async_sendall(conn.conn, msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = await receive_message(conn, request_id)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
)
response_doc = unpacked_docs[0]
if client:
await client._process_response(response_doc, session)
if check:
_async_helpers._check_command_response(
response_doc,
conn.max_wire_version,
allowable_errors,
parse_write_concern_error=parse_write_concern_error,
)
except Exception as exc:
duration = datetime.datetime.now() - start
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _async_message._convert_exception(exc)
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(spec)),
databaseName=dbname,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if publish:
assert listeners is not None
assert address is not None
listeners.publish_command_failure(
duration,
failure,
name,
request_id,
address,
conn.server_connection_id,
service_id=conn.service_id,
database_name=dbname,
)
raise
duration = datetime.datetime.now() - start
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=response_doc,
commandName=next(iter(spec)),
databaseName=dbname,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
speculative_authenticate="speculativeAuthenticate" in orig,
)
if publish:
assert listeners is not None
assert address is not None
listeners.publish_command_success(
duration,
response_doc,
name,
request_id,
address,
conn.server_connection_id,
service_id=conn.service_id,
speculative_hello=speculative_hello,
database_name=dbname,
)
if client and client._encrypter and reply:
decrypted = client._encrypter.decrypt(reply.raw_command_response())
response_doc = cast(
"_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
)
return response_doc # type: ignore[return-value]
async def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
await _receive_data_on_socket(conn, 16, deadline)
)
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await _receive_data_on_socket(conn, 9, deadline)
)
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
else:
data = await _receive_data_on_socket(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
timed_out = False
# Check if the connection's socket has been manually closed
if sock.fileno() == -1:
return
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
readable = True
else:
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled")
if readable:
return
if timed_out:
raise socket.timeout("timed out")
await asyncio.sleep(0)
async def _receive_data_on_socket(
conn: Connection, length: int, deadline: Optional[float]
) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
await wait_for_read(conn, deadline)
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None:
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except OSError as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk_length == 0:
raise OSError("connection closed")
bytes_read += chunk_length
return mv

View File

@ -0,0 +1,625 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operation class definitions."""
from __future__ import annotations
import enum
from typing import (
TYPE_CHECKING,
Any,
Generic,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from bson.raw_bson import RawBSONDocument
from pymongo.asynchronous import helpers
from pymongo.asynchronous.collation import validate_collation_or_none
from pymongo.asynchronous.common import validate_is_mapping, validate_list
from pymongo.asynchronous.helpers import _gen_index_name, _index_document, _index_list
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.write_concern import validate_boolean
if TYPE_CHECKING:
from pymongo.asynchronous.bulk import _Bulk
_IS_SYNC = False
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
_IndexList = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_IndexKeyHint = Union[str, _IndexList]
class _Op(str, enum.Enum):
ABORT = "abortTransaction"
AGGREGATE = "aggregate"
COMMIT = "commitTransaction"
COUNT = "count"
CREATE = "create"
CREATE_INDEXES = "createIndexes"
CREATE_SEARCH_INDEXES = "createSearchIndexes"
DELETE = "delete"
DISTINCT = "distinct"
DROP = "drop"
DROP_DATABASE = "dropDatabase"
DROP_INDEXES = "dropIndexes"
DROP_SEARCH_INDEXES = "dropSearchIndexes"
END_SESSIONS = "endSessions"
FIND_AND_MODIFY = "findAndModify"
FIND = "find"
INSERT = "insert"
LIST_COLLECTIONS = "listCollections"
LIST_INDEXES = "listIndexes"
LIST_SEARCH_INDEX = "listSearchIndexes"
LIST_DATABASES = "listDatabases"
UPDATE = "update"
UPDATE_INDEX = "updateIndex"
UPDATE_SEARCH_INDEX = "updateSearchIndex"
RENAME = "rename"
GETMORE = "getMore"
KILL_CURSORS = "killCursors"
TEST = "testOperation"
class InsertOne(Generic[_DocumentType]):
"""Represents an insert_one operation."""
__slots__ = ("_doc",)
def __init__(self, document: _DocumentType) -> None:
"""Create an InsertOne instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param document: The document to insert. If the document is missing an
_id field one will be added.
"""
self._doc = document
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_insert(self._doc) # type: ignore[arg-type]
def __repr__(self) -> str:
return f"InsertOne({self._doc!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return other._doc == self._doc
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteOne:
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteOne instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the document to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
1,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteMany:
"""Represents a delete_many operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteMany instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the documents to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
0,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class ReplaceOne(Generic[_DocumentType]):
"""Represents a replace_one operation."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
replacement: Union[_DocumentType, RawBSONDocument],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a ReplaceOne instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the document to replace.
:param replacement: The new document.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the ``collation`` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = replacement
self._upsert = upsert
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_replace(
self._filter,
self._doc,
self._upsert,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
other._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._hint,
)
class _UpdateOp:
"""Private base class for update operations."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
doc: Union[Mapping[str, Any], _Pipeline],
upsert: bool,
collation: Optional[_CollationIn],
array_filters: Optional[list[Mapping[str, Any]]],
hint: Optional[_IndexKeyHint],
):
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if array_filters is not None:
validate_list("array_filters", array_filters)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = doc
self._upsert = upsert
self._collation = collation
self._array_filters = array_filters
def __eq__(self, other: object) -> bool:
if isinstance(other, type(self)):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._array_filters,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
return NotImplemented
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
class UpdateOne(_UpdateOp):
"""Represents an update_one operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Represents an update_one operation.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the document to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
False,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class UpdateMany(_UpdateOp):
"""Represents an update_many operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create an UpdateMany instance.
For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.
:param filter: A query that matches the documents to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
True,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class IndexModel:
"""Represents an index to create."""
__slots__ = ("__document",)
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
"""Create an Index instance.
For use with :meth:`~pymongo.collection.AsyncCollection.create_indexes`.
Takes either a single key or a list containing (key, direction) pairs
or keys. If no direction is given, :data:`~pymongo.ASCENDING` will
be assumed.
The key(s) must be an instance of :class:`str`, and the direction(s) must
be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
:data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
Valid options include, but are not limited to:
- `name`: custom name to use for this index - if none is
given, a name will be generated.
- `unique`: if ``True``, creates a uniqueness constraint on the index.
- `background`: if ``True``, this index should be created in the
background.
- `sparse`: if ``True``, omit from the index any documents that lack
the indexed field.
- `bucketSize`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity
to a given longitude and latitude.
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
index.
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
index.
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
collection. MongoDB will automatically delete documents from
this collection after <int> seconds. The indexed field must
be a UTC datetime or the data will not expire.
- `partialFilterExpression`: A document that specifies a filter for
a partial index.
- `collation`: An instance of :class:`~pymongo.collation.Collation`
that specifies the collation to use.
- `wildcardProjection`: Allows users to include or exclude specific
field paths from a `wildcard index`_ using the { "$**" : 1} key
pattern. Requires MongoDB >= 4.2.
- `hidden`: if ``True``, this index will be hidden from the query
planner and will not be evaluated as part of query plan
selection. Requires MongoDB >= 4.4.
See the MongoDB documentation for a full list of supported options by
server version.
:param keys: a single key or a list containing (key, direction) pairs
or keys specifying the index to create.
:param kwargs: any additional index creation
options (see the above list) should be passed as keyword
arguments.
.. versionchanged:: 3.11
Added the ``hidden`` option.
.. versionchanged:: 3.2
Added the ``partialFilterExpression`` option to support partial
indexes.
.. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
"""
keys = _index_list(keys)
if kwargs.get("name") is None:
kwargs["name"] = _gen_index_name(keys)
kwargs["key"] = _index_document(keys)
collation = validate_collation_or_none(kwargs.pop("collation", None))
self.__document = kwargs
if collation is not None:
self.__document["collation"] = collation
@property
def document(self) -> dict[str, Any]:
"""An index document suitable for passing to the createIndexes
command.
"""
return self.__document
class SearchIndexModel:
"""Represents a search index to create."""
__slots__ = ("__document",)
def __init__(
self,
definition: Mapping[str, Any],
name: Optional[str] = None,
type: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a Search Index instance.
For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`.
:param definition: The definition for this index.
:param name: The name for this index, if present.
:param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
:param kwargs: Keyword arguments supplying any additional options.
.. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
.. versionadded:: 4.5
.. versionchanged:: 4.7
Added the type and kwargs arguments.
"""
self.__document: dict[str, Any] = {}
if name is not None:
self.__document["name"] = name
self.__document["definition"] = definition
if type is not None:
self.__document["type"] = type
self.__document.update(kwargs)
@property
def document(self) -> Mapping[str, Any]:
"""The document for this index."""
return self.__document

View File

@ -0,0 +1,209 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run a target function on a background thread."""
from __future__ import annotations
import asyncio
import sys
import threading
import time
import weakref
from typing import Any, Optional
from pymongo.lock import _ALock, _create_lock
_IS_SYNC = False
class PeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background thread.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying thread.
"""
# threading.Event and its internal condition variable are expensive
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
# The executor's design is constrained by several Python issues, see
# "periodic_executor.rst" in this repository.
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread: Optional[threading.Thread] = None
self._name = name
self._skip_sleep = False
self._thread_will_exit = False
self._lock = _ALock(_create_lock())
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
"""
with self._lock:
if self._thread_will_exit:
# If the background thread has read self._stopped as True
# there is a chance that it has not yet exited. The call to
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
assert self._thread is not None
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started: Any = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
# Thread terminated.
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)
# Mitigation to RuntimeError firing when thread starts on shutdown
# https://github.com/python/cpython/issues/114570
try:
thread.start()
except RuntimeError as e:
if "interpreter shutdown" in str(e) or sys.is_finalizing():
self._thread = None
return
raise
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
def join(self, timeout: Optional[int] = None) -> None:
if self._thread is not None:
try:
self._thread.join(timeout)
except (ReferenceError, RuntimeError):
# Thread already terminated, or not yet started.
pass
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _should_stop(self) -> bool:
async with self._lock:
if self._stopped:
self._thread_will_exit = True
return True
return False
async def _run(self) -> None:
while not await self._should_stop():
try:
if not await self._target():
self._stopped = True
break
except BaseException:
async with self._lock:
self._stopped = True
self._thread_will_exit = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()
def _register_executor(executor: PeriodicExecutor) -> None:
ref = weakref.ref(executor, _on_executor_deleted)
_EXECUTORS.add(ref)
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
_EXECUTORS.remove(ref)
def _shutdown_executors() -> None:
if _EXECUTORS is None:
return
# Copy the set. Stopping threads has the side effect of removing executors.
executors = list(_EXECUTORS)
# First signal all executors to close...
for ref in executors:
executor = ref()
if executor:
executor.close()
# ...then try to join them.
for ref in executors:
executor = ref()
if executor:
executor.join(1)
executor = None

2128
pymongo/asynchronous/pool.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,624 @@
# Copyright 2012-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License",
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for choosing which member of a replica set to read from."""
from __future__ import annotations
from collections import abc
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
from pymongo.asynchronous import max_staleness_selectors
from pymongo.asynchronous.server_selectors import (
member_with_tags_server_selector,
secondary_with_tags_server_selector,
)
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from pymongo.asynchronous.server_selectors import Selection
from pymongo.asynchronous.topology_description import TopologyDescription
_IS_SYNC = False
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
_SECONDARY_PREFERRED = 3
_NEAREST = 4
_MONGOS_MODES = (
"primary",
"primaryPreferred",
"secondary",
"secondaryPreferred",
"nearest",
)
_Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
"""Validate tag sets for a MongoClient."""
if tag_sets is None:
return tag_sets
if not isinstance(tag_sets, (list, tuple)):
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
if len(tag_sets) == 0:
raise ValueError(
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
)
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
f"Tag set {tags!r} invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping"
)
return list(tag_sets)
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
# Some duplication with common.py to avoid import cycle.
def _validate_max_staleness(max_staleness: Any) -> int:
"""Validate max_staleness."""
if max_staleness == -1:
return -1
if not isinstance(max_staleness, int):
raise TypeError(_invalid_max_staleness_msg(max_staleness))
if max_staleness <= 0:
raise ValueError(_invalid_max_staleness_msg(max_staleness))
return max_staleness
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
"""Validate hedge."""
if hedge is None:
return None
if not isinstance(hedge, dict):
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
return hedge
class _ServerMode:
"""Base class for all read preferences."""
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
def __init__(
self,
mode: int,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
self.__mongos_mode = _MONGOS_MODES[mode]
self.__mode = mode
self.__tag_sets = _validate_tag_sets(tag_sets)
self.__max_staleness = _validate_max_staleness(max_staleness)
self.__hedge = _validate_hedge(hedge)
@property
def name(self) -> str:
"""The name of this read preference."""
return self.__class__.__name__
@property
def mongos_mode(self) -> str:
"""The mongos mode of this read preference."""
return self.__mongos_mode
@property
def document(self) -> dict[str, Any]:
"""Read preference as a document."""
doc: dict[str, Any] = {"mode": self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc["tags"] = self.__tag_sets
if self.__max_staleness != -1:
doc["maxStalenessSeconds"] = self.__max_staleness
if self.__hedge not in (None, {}):
doc["hedge"] = self.__hedge
return doc
@property
def mode(self) -> int:
"""The mode of this read preference instance."""
return self.__mode
@property
def tag_sets(self) -> _TagSets:
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
read only from members whose ``dc`` tag has the value ``"ny"``.
To specify a priority-order for tag sets, provide a list of
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
set, ``{}``, means "read from any member that matches the mode,
ignoring tags." MongoClient tries each set of tags in turn
until it finds a set of tags with at least one matching member.
For example, to only send a query to an analytic node::
Nearest(tag_sets=[{"node":"analytics"}])
Or using :class:`SecondaryPreferred`::
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
.. seealso:: `Data-Center Awareness
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
"""
return list(self.__tag_sets) if self.__tag_sets else [{}]
@property
def max_staleness(self) -> int:
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum.
"""
return self.__max_staleness
@property
def hedge(self) -> Optional[_Hedge]:
"""The read preference ``hedge`` parameter.
A dictionary that configures how the server will perform hedged reads.
It consists of the following keys:
- ``enabled``: Enables or disables hedged reads in sharded clusters.
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
``nearest`` read preference. To explicitly enable hedged reads, set
the ``enabled`` key to ``true``::
>>> Nearest(hedge={'enabled': True})
To explicitly disable hedged reads, set the ``enabled`` key to
``False``::
>>> Nearest(hedge={'enabled': False})
.. versionadded:: 3.11
"""
return self.__hedge
@property
def min_wire_version(self) -> int:
"""The wire protocol version the server must support.
Some read preferences impose version requirements on all servers (e.g.
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
All servers' maxWireVersion must be at least this read preference's
`min_wire_version`, or the driver raises
:exc:`~pymongo.errors.ConfigurationError`.
"""
return 0 if self.__max_staleness == -1 else 5
def __repr__(self) -> str:
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
self.name,
self.__tag_sets,
self.__max_staleness,
self.__hedge,
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return (
self.mode == other.mode
and self.tag_sets == other.tag_sets
and self.max_staleness == other.max_staleness
and self.hedge == other.hedge
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __getstate__(self) -> dict[str, Any]:
"""Return value of object for pickling.
Needed explicitly because __slots__() defined.
"""
return {
"mode": self.__mode,
"tag_sets": self.__tag_sets,
"max_staleness": self.__max_staleness,
"hedge": self.__hedge,
}
def __setstate__(self, value: Mapping[str, Any]) -> None:
"""Restore from pickling."""
self.__mode = value["mode"]
self.__mongos_mode = _MONGOS_MODES[self.__mode]
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
self.__hedge = _validate_hedge(value["hedge"])
def __call__(self, selection: Selection) -> Selection:
return selection
class Primary(_ServerMode):
"""Primary read preference.
* When directly connected to one mongod queries are allowed if the server
is standalone or a replica set primary.
* When connected to a mongos queries are sent to the primary of a shard.
* When connected to a replica set queries are sent to the primary of
the replica set.
"""
__slots__ = ()
def __init__(self) -> None:
super().__init__(_PRIMARY)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return selection.primary_selection
def __repr__(self) -> str:
return "Primary()"
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return other.mode == _PRIMARY
return NotImplemented
class PrimaryPreferred(_ServerMode):
"""PrimaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are sent to the primary of a shard if
available, otherwise a shard secondary.
* When connected to a replica set queries are sent to the primary if
available, otherwise a secondary.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to an available secondary until the
primary of the replica set is discovered.
:param tag_sets: The :attr:`~tag_sets` to use if the primary is not
available.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` to use if the primary is not available.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
if selection.primary:
return selection.primary_selection
else:
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class Secondary(_ServerMode):
"""Secondary read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries. An error is raised if no secondaries are available.
* When connected to a replica set queries are distributed among
secondaries. An error is raised if no secondaries are available.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class SecondaryPreferred(_ServerMode):
"""SecondaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries, or the shard primary if no secondary is available.
* When connected to a replica set queries are distributed among
secondaries, or the primary if no secondary is available.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to the primary of the replica set until
an available secondary is discovered.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
secondaries = secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
if secondaries:
return secondaries
else:
return selection.primary_selection
class Nearest(_ServerMode):
"""Nearest read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among all members of
a shard.
* When connected to a replica set queries are distributed among all
members.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return member_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class _AggWritePref:
"""Agg $out/$merge write preference.
* If there are readable servers and there is any pre-5.0 server, use
primary read preference.
* Otherwise use `pref` read preference.
:param pref: The read preference to use on MongoDB 5.0+.
"""
__slots__ = ("pref", "effective_pref")
def __init__(self, pref: _ServerMode):
self.pref = pref
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
def selection_hook(self, topology_description: TopologyDescription) -> None:
common_wv = topology_description.common_wire_version
if (
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
and common_wv
and common_wv < 13
):
self.effective_pref = ReadPreference.PRIMARY
else:
self.effective_pref = self.pref
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return self.effective_pref(selection)
def __repr__(self) -> str:
return f"_AggWritePref(pref={self.pref!r})"
# Proxy other calls to the effective_pref so that _AggWritePref can be
# used in place of an actual read preference.
def __getattr__(self, name: str) -> Any:
return getattr(self.effective_pref, name)
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
def make_read_preference(
mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
) -> _ServerMode:
if mode == _PRIMARY:
if tag_sets not in (None, [{}]):
raise ConfigurationError("Read preference primary cannot be combined with tags")
if max_staleness != -1:
raise ConfigurationError(
"Read preference primary cannot be combined with maxStalenessSeconds"
)
return Primary()
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
_MODES = (
"PRIMARY",
"PRIMARY_PREFERRED",
"SECONDARY",
"SECONDARY_PREFERRED",
"NEAREST",
)
class ReadPreference:
"""An enum that defines some commonly used read preference modes.
Apps can also create a custom read preference, for example::
Nearest(tag_sets=[{"node":"analytics"}])
See :doc:`/examples/high_availability` for code examples.
A read preference is used in three cases:
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
set primary.
- All other modes allow queries to standalone servers, to a replica set
primary, or to replica set secondaries.
:class:`~pymongo.mongo_client.MongoClient` initialized with the
``replicaSet`` option:
- ``PRIMARY``: Read from the primary. This is the default, and provides the
strongest consistency. If no primary is available, raise
:class:`~pymongo.errors.AutoReconnect`.
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
none, read from a secondary.
- ``SECONDARY``: Read from a secondary. If no secondary is available,
raise :class:`~pymongo.errors.AutoReconnect`.
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
from the primary.
- ``NEAREST``: Read from any member.
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
sharded cluster of replica sets:
- ``PRIMARY``: Read from the primary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
This is the default.
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
none, read from a secondary of the shard.
- ``SECONDARY``: Read from a secondary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
otherwise from the shard primary.
- ``NEAREST``: Read from any shard member.
"""
PRIMARY = Primary()
PRIMARY_PREFERRED = PrimaryPreferred()
SECONDARY = Secondary()
SECONDARY_PREFERRED = SecondaryPreferred()
NEAREST = Nearest()
def read_pref_mode_from_name(name: str) -> int:
"""Get the read preference mode from mongos/uri name."""
return _MONGOS_MODES.index(name)
class MovingAverage:
"""Tracks an exponentially-weighted moving average."""
average: Optional[float]
def __init__(self) -> None:
self.average = None
def add_sample(self, sample: float) -> None:
if sample < 0:
# Likely system time change while waiting for hello response
# and not using time.monotonic. Ignore it, the next one will
# probably be valid.
return
if self.average is None:
self.average = sample
else:
# The Server Selection Spec requires an exponentially weighted
# average with alpha = 0.2.
self.average = 0.8 * self.average + 0.2 * sample
def get(self) -> Optional[float]:
"""Get the calculated average, or None if no samples yet."""
return self.average
def reset(self) -> None:
self.average = None

View File

@ -0,0 +1,133 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent a response from the server."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.asynchronous.message import _OpMsg, _OpReply
from pymongo.asynchronous.pool import Connection
from pymongo.asynchronous.typings import _Address, _DocumentOut
_IS_SYNC = False
class Response:
__slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs")
def __init__(
self,
data: Union[_OpMsg, _OpReply],
address: _Address,
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: Sequence[Mapping[str, Any]],
):
"""Represent a response from the server.
:param data: A network response message.
:param address: (host, port) of the source server.
:param request_id: The request id of this operation.
:param duration: The duration of the operation.
:param from_command: if the response is the result of a db command.
"""
self._data = data
self._address = address
self._request_id = request_id
self._duration = duration
self._from_command = from_command
self._docs = docs
@property
def data(self) -> Union[_OpMsg, _OpReply]:
"""Server response's raw BSON bytes."""
return self._data
@property
def address(self) -> _Address:
"""(host, port) of the source server."""
return self._address
@property
def request_id(self) -> int:
"""The request id of this operation."""
return self._request_id
@property
def duration(self) -> Optional[timedelta]:
"""The duration of the operation."""
return self._duration
@property
def from_command(self) -> bool:
"""If the response is a result from a db command."""
return self._from_command
@property
def docs(self) -> Sequence[Mapping[str, Any]]:
"""The decoded document(s)."""
return self._docs
class PinnedResponse(Response):
__slots__ = ("_conn", "_more_to_come")
def __init__(
self,
data: Union[_OpMsg, _OpReply],
address: _Address,
conn: Connection,
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: list[_DocumentOut],
more_to_come: bool,
):
"""Represent a response to an exhaust cursor's initial query.
:param data: A network response message.
:param address: (host, port) of the source server.
:param conn: The Connection used for the initial query.
:param request_id: The request id of this operation.
:param duration: The duration of the operation.
:param from_command: If the response is the result of a db command.
:param docs: List of documents.
:param more_to_come: Bool indicating whether cursor is ready to be
exhausted.
"""
super().__init__(data, address, request_id, duration, from_command, docs)
self._conn = conn
self._more_to_come = more_to_come
@property
def conn(self) -> Connection:
"""The Connection used for the initial query.
The server will send batches on this socket, without waiting for
getMores from the client, until the result set is exhausted or there
is an error.
"""
return self._conn
@property
def more_to_come(self) -> bool:
"""If true, server is ready to send batches on the socket until the
result set is exhausted or there is an error.
"""
return self._more_to_come

View File

@ -0,0 +1,355 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Communicate with one MongoDB server in a topology."""
from __future__ import annotations
import logging
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Optional,
Union,
)
from bson import _decode_all_selective
from pymongo.asynchronous.helpers import _check_command_response, _handle_reauth
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query
from pymongo.asynchronous.response import PinnedResponse, Response
from pymongo.errors import NotPrimaryError, OperationFailure
if TYPE_CHECKING:
from queue import Queue
from weakref import ReferenceType
from bson.objectid import ObjectId
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
from pymongo.asynchronous.monitor import Monitor
from pymongo.asynchronous.monitoring import _EventListeners
from pymongo.asynchronous.pool import Connection, Pool
from pymongo.asynchronous.read_preferences import _ServerMode
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.typings import _DocumentOut
_IS_SYNC = False
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
class Server:
def __init__(
self,
server_description: ServerDescription,
pool: Pool,
monitor: Monitor,
topology_id: Optional[ObjectId] = None,
listeners: Optional[_EventListeners] = None,
events: Optional[ReferenceType[Queue]] = None,
) -> None:
"""Represent one MongoDB server."""
self._description = server_description
self._pool = pool
self._monitor = monitor
self._topology_id = topology_id
self._publish = listeners is not None and listeners.enabled_for_server
self._listener = listeners
self._events = None
if self._publish:
self._events = events() # type: ignore[misc]
async def open(self) -> None:
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
if not self._pool.opts.load_balanced:
self._monitor.open()
async def reset(self, service_id: Optional[ObjectId] = None) -> None:
"""Clear the connection pool."""
await self.pool.reset(service_id)
async def close(self) -> None:
"""Clear the connection pool and stop the monitor.
Reconnect with open().
"""
if self._publish:
assert self._listener is not None
assert self._events is not None
self._events.put(
(
self._listener.publish_server_closed,
(self._description.address, self._topology_id),
)
)
await self._monitor.close()
await self._pool.close()
def request_check(self) -> None:
"""Check the server's state soon."""
self._monitor.request_check()
@_handle_reauth
async def run_operation(
self,
conn: Connection,
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., list[_DocumentOut]],
client: AsyncMongoClient,
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.
This method is used only to run _Query/_GetMore operations from
cursors.
Can raise ConnectionFailure, OperationFailure, etc.
:param conn: A Connection instance.
:param operation: A _Query or _GetMore object.
:param read_preference: The read preference to use.
:param listeners: Instance of _EventListeners or None.
:param unpack_res: A callable that decodes the wire protocol response.
"""
duration = None
assert listeners is not None
publish = listeners.enabled_for_commands
start = datetime.now()
use_cmd = operation.use_command(conn)
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
if more_to_come:
request_id = 0
else:
message = await operation.get_message(read_preference, conn, use_cmd)
request_id, data, max_doc_size = self._split_message(message)
cmd, dbn = await operation.as_command(conn)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=dbn,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
)
if publish:
cmd, dbn = await operation.as_command(conn)
if "$db" not in cmd:
cmd["$db"] = dbn
assert listeners is not None
listeners.publish_command_start(
cmd,
dbn,
request_id,
conn.address,
conn.server_connection_id,
service_id=conn.service_id,
)
try:
if more_to_come:
reply = await conn.receive_message(None)
else:
await conn.send_message(data, max_doc_size)
reply = await conn.receive_message(request_id)
# Unpack and check for command errors.
if use_cmd:
user_fields = _CURSOR_DOC_FIELDS
legacy_response = False
else:
user_fields = None
legacy_response = True
docs = unpack_res(
reply,
operation.cursor_id,
operation.codec_options,
legacy_response=legacy_response,
user_fields=user_fields,
)
if use_cmd:
first = docs[0]
await operation.client._process_response(first, operation.session)
_check_command_response(first, conn.max_wire_version)
except Exception as exc:
duration = datetime.now() - start
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=dbn,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if publish:
assert listeners is not None
listeners.publish_command_failure(
duration,
failure,
operation.name,
request_id,
conn.address,
conn.server_connection_id,
service_id=conn.service_id,
database_name=dbn,
)
raise
duration = datetime.now() - start
# Must publish in find / getMore / explain command response
# format.
if use_cmd:
res = docs[0]
elif operation.name == "explain":
res = docs[0] if docs else {}
else:
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr]
if operation.name == "find":
res["cursor"]["firstBatch"] = docs
else:
res["cursor"]["nextBatch"] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=res,
commandName=next(iter(cmd)),
databaseName=dbn,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
)
if publish:
assert listeners is not None
listeners.publish_command_success(
duration,
res,
operation.name,
request_id,
conn.address,
conn.server_connection_id,
service_id=conn.service_id,
database_name=dbn,
)
# Decrypt response.
client = operation.client
if client and client._encrypter:
if use_cmd:
decrypted = client._encrypter.decrypt(reply.raw_command_response())
docs = _decode_all_selective(decrypted, operation.codec_options, user_fields)
response: Response
if client._should_pin_cursor(operation.session) or operation.exhaust:
conn.pin_cursor()
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the
# more_to_come flag is set.
more_to_come = reply.more_to_come
else:
# In OP_REPLY, the server keeps sending until cursor_id is 0.
more_to_come = bool(operation.exhaust and reply.cursor_id)
if operation.conn_mgr:
operation.conn_mgr.update_exhaust(more_to_come)
response = PinnedResponse(
data=reply,
address=self._description.address,
conn=conn,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs,
more_to_come=more_to_come,
)
else:
response = Response(
data=reply,
address=self._description.address,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs,
)
return response
async def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None
) -> AsyncContextManager[Connection]:
return self.pool.checkout(handler)
@property
def description(self) -> ServerDescription:
return self._description
@description.setter
def description(self, server_description: ServerDescription) -> None:
assert server_description.address == self._description.address
self._description = server_description
@property
def pool(self) -> Pool:
return self._pool
def _split_message(
self, message: Union[tuple[int, Any], tuple[int, Any, int]]
) -> tuple[int, Any, int]:
"""Return request_id, data, max_doc_size.
:param message: (request_id, data, max_doc_size) or (request_id, data)
"""
if len(message) == 3:
return message # type: ignore[return-value]
else:
# get_more and kill_cursors messages don't include BSON documents.
request_id, data = message # type: ignore[misc]
return request_id, data, 0
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._description!r}>"

View File

@ -0,0 +1,301 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent one server the driver is connected to."""
from __future__ import annotations
import time
import warnings
from typing import Any, Mapping, Optional
from bson import EPOCH_NAIVE
from bson.objectid import ObjectId
from pymongo.asynchronous.hello import Hello
from pymongo.asynchronous.typings import ClusterTime, _Address
from pymongo.server_type import SERVER_TYPE
_IS_SYNC = False
class ServerDescription:
"""Immutable representation of one server.
:param address: A (host, port) pair
:param hello: Optional Hello instance
:param round_trip_time: Optional float
:param error: Optional, the last error attempting to connect to the server
:param round_trip_time: Optional float, the min latency from the most recent samples
"""
__slots__ = (
"_address",
"_server_type",
"_all_hosts",
"_tags",
"_replica_set_name",
"_primary",
"_max_bson_size",
"_max_message_size",
"_max_write_batch_size",
"_min_wire_version",
"_max_wire_version",
"_round_trip_time",
"_min_round_trip_time",
"_me",
"_is_writable",
"_is_readable",
"_ls_timeout_minutes",
"_error",
"_set_version",
"_election_id",
"_cluster_time",
"_last_write_date",
"_last_update_time",
"_topology_version",
)
def __init__(
self,
address: _Address,
hello: Optional[Hello] = None,
round_trip_time: Optional[float] = None,
error: Optional[Exception] = None,
min_round_trip_time: float = 0.0,
) -> None:
self._address = address
if not hello:
hello = Hello({})
self._server_type = hello.server_type
self._all_hosts = hello.all_hosts
self._tags = hello.tags
self._replica_set_name = hello.replica_set_name
self._primary = hello.primary
self._max_bson_size = hello.max_bson_size
self._max_message_size = hello.max_message_size
self._max_write_batch_size = hello.max_write_batch_size
self._min_wire_version = hello.min_wire_version
self._max_wire_version = hello.max_wire_version
self._set_version = hello.set_version
self._election_id = hello.election_id
self._cluster_time = hello.cluster_time
self._is_writable = hello.is_writable
self._is_readable = hello.is_readable
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
self._round_trip_time = round_trip_time
self._min_round_trip_time = min_round_trip_time
self._me = hello.me
self._last_update_time = time.monotonic()
self._error = error
self._topology_version = hello.topology_version
if error:
details = getattr(error, "details", None)
if isinstance(details, dict):
self._topology_version = details.get("topologyVersion")
self._last_write_date: Optional[float]
if hello.last_write_date:
# Convert from datetime to seconds.
delta = hello.last_write_date - EPOCH_NAIVE
self._last_write_date = delta.total_seconds()
else:
self._last_write_date = None
@property
def address(self) -> _Address:
"""The address (host, port) of this server."""
return self._address
@property
def server_type(self) -> int:
"""The type of this server."""
return self._server_type
@property
def server_type_name(self) -> str:
"""The server type as a human readable string.
.. versionadded:: 3.4
"""
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@property
def tags(self) -> Mapping[str, Any]:
return self._tags
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._replica_set_name
@property
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
return self._primary
@property
def max_bson_size(self) -> int:
return self._max_bson_size
@property
def max_message_size(self) -> int:
return self._max_message_size
@property
def max_write_batch_size(self) -> int:
return self._max_write_batch_size
@property
def min_wire_version(self) -> int:
return self._min_wire_version
@property
def max_wire_version(self) -> int:
return self._max_wire_version
@property
def set_version(self) -> Optional[int]:
return self._set_version
@property
def election_id(self) -> Optional[ObjectId]:
return self._election_id
@property
def cluster_time(self) -> Optional[ClusterTime]:
return self._cluster_time
@property
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
warnings.warn(
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
DeprecationWarning,
stacklevel=2,
)
return self._set_version, self._election_id
@property
def me(self) -> Optional[tuple[str, int]]:
return self._me
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._ls_timeout_minutes
@property
def last_write_date(self) -> Optional[float]:
return self._last_write_date
@property
def last_update_time(self) -> float:
return self._last_update_time
@property
def round_trip_time(self) -> Optional[float]:
"""The current average latency or None."""
# This override is for unittesting only!
if self._address in self._host_to_round_trip_time:
return self._host_to_round_trip_time[self._address]
return self._round_trip_time
@property
def min_round_trip_time(self) -> float:
"""The min latency from the most recent samples."""
return self._min_round_trip_time
@property
def error(self) -> Optional[Exception]:
"""The last error attempting to connect to the server, or None."""
return self._error
@property
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self) -> bool:
return self._is_readable
@property
def mongos(self) -> bool:
return self._server_type == SERVER_TYPE.Mongos
@property
def is_server_type_known(self) -> bool:
return self.server_type != SERVER_TYPE.Unknown
@property
def retryable_writes_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return (
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
) or self._server_type == SERVER_TYPE.LoadBalancer
@property
def retryable_reads_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return self._max_wire_version >= 6
@property
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._topology_version
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
unknown = ServerDescription(self.address, error=error)
unknown._topology_version = self.topology_version
return unknown
def __eq__(self, other: Any) -> bool:
if isinstance(other, ServerDescription):
return (
(self._address == other.address)
and (self._server_type == other.server_type)
and (self._min_wire_version == other.min_wire_version)
and (self._max_wire_version == other.max_wire_version)
and (self._me == other.me)
and (self._all_hosts == other.all_hosts)
and (self._tags == other.tags)
and (self._replica_set_name == other.replica_set_name)
and (self._set_version == other.set_version)
and (self._election_id == other.election_id)
and (self._primary == other.primary)
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
and (self._error == other.error)
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
errmsg = ""
if self.error:
errmsg = f", error={self.error!r}"
return "<{} {} server_type: {}, rtt: {}{}>".format(
self.__class__.__name__,
self.address,
self.server_type_name,
self.round_trip_time,
errmsg,
)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time: dict = {}

View File

@ -0,0 +1,175 @@
# Copyright 2014-2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Criteria to select some ServerDescriptions from a TopologyDescription."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast
from pymongo.server_type import SERVER_TYPE
if TYPE_CHECKING:
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.topology_description import TopologyDescription
_IS_SYNC = False
T = TypeVar("T")
TagSet = Mapping[str, Any]
TagSets = Sequence[TagSet]
class Selection:
"""Input or output of a server selector function."""
@classmethod
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
known_servers = topology_description.known_servers
primary = None
for sd in known_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
primary = sd
break
return Selection(
topology_description,
topology_description.known_servers,
topology_description.common_wire_version,
primary,
)
def __init__(
self,
topology_description: TopologyDescription,
server_descriptions: list[ServerDescription],
common_wire_version: Optional[int],
primary: Optional[ServerDescription],
):
self.topology_description = topology_description
self.server_descriptions = server_descriptions
self.primary = primary
self.common_wire_version = common_wire_version
def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection:
return Selection(
self.topology_description, server_descriptions, self.common_wire_version, self.primary
)
def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]:
secondaries = secondary_server_selector(self)
if secondaries.server_descriptions:
return max(
secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date)
)
return None
@property
def primary_selection(self) -> Selection:
primaries = [self.primary] if self.primary else []
return self.with_server_descriptions(primaries)
@property
def heartbeat_frequency(self) -> int:
return self.topology_description.heartbeat_frequency
@property
def topology_type(self) -> int:
return self.topology_description.topology_type
def __bool__(self) -> bool:
return bool(self.server_descriptions)
def __getitem__(self, item: int) -> ServerDescription:
return self.server_descriptions[item]
def any_server_selector(selection: T) -> T:
return selection
def readable_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.is_readable]
)
def writable_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.is_writable]
)
def secondary_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary]
)
def arbiter_server_selector(selection: Selection) -> Selection:
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter]
)
def writable_preferred_server_selector(selection: Selection) -> Selection:
"""Like PrimaryPreferred but doesn't use tags or latency."""
return writable_server_selector(selection) or secondary_server_selector(selection)
def apply_single_tag_set(tag_set: TagSet, selection: Selection) -> Selection:
"""All servers matching one tag set.
A tag set is a dict. A server matches if its tags are a superset:
A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}.
The empty tag set {} matches any server.
"""
def tags_match(server_tags: Mapping[str, Any]) -> bool:
for key, value in tag_set.items():
if key not in server_tags or server_tags[key] != value:
return False
return True
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if tags_match(s.tags)]
)
def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection:
"""All servers match a list of tag sets.
tag_sets is a list of dicts. The empty tag set {} matches any server,
and may be provided at the end of the list as a fallback. So
[{'a': 'value'}, {}] expresses a preference for servers tagged
{'a': 'value'}, but accepts any server if none matches the first
preference.
"""
for tag_set in tag_sets:
with_tag_set = apply_single_tag_set(tag_set, selection)
if with_tag_set:
return with_tag_set
return selection.with_server_descriptions([])
def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
"""All near-enough secondaries matching the tag sets."""
return apply_tag_sets(tag_sets, secondary_server_selector(selection))
def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
"""All near-enough members matching the tag sets."""
return apply_tag_sets(tag_sets, readable_server_selector(selection))

View File

@ -0,0 +1,170 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Represent MongoClient's configuration."""
from __future__ import annotations
import threading
import traceback
from typing import Any, Collection, Optional, Type, Union
from bson.objectid import ObjectId
from pymongo.asynchronous import common, monitor, pool
from pymongo.asynchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.asynchronous.pool import Pool, PoolOptions
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector
from pymongo.errors import ConfigurationError
_IS_SYNC = False
class TopologySettings:
def __init__(
self,
seeds: Optional[Collection[tuple[str, int]]] = None,
replica_set_name: Optional[str] = None,
pool_class: Optional[Type[Pool]] = None,
pool_options: Optional[PoolOptions] = None,
monitor_class: Optional[Type[monitor.Monitor]] = None,
condition_class: Optional[Type[threading.Condition]] = None,
local_threshold_ms: int = LOCAL_THRESHOLD_MS,
server_selection_timeout: int = SERVER_SELECTION_TIMEOUT,
heartbeat_frequency: int = common.HEARTBEAT_FREQUENCY,
server_selector: Optional[_ServerSelector] = None,
fqdn: Optional[str] = None,
direct_connection: Optional[bool] = False,
load_balanced: Optional[bool] = None,
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
):
"""Represent MongoClient's configuration.
Take a list of (host, port) pairs and optional replica set name.
"""
if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL:
raise ConfigurationError(
"heartbeatFrequencyMS cannot be less than %d"
% (common.MIN_HEARTBEAT_INTERVAL * 1000,)
)
self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)]
self._replica_set_name = replica_set_name
self._pool_class: Type[Pool] = pool_class or pool.Pool
self._pool_options: PoolOptions = pool_options or PoolOptions()
self._monitor_class: Type[monitor.Monitor] = monitor_class or monitor.Monitor
self._condition_class: Type[threading.Condition] = condition_class or threading.Condition
self._local_threshold_ms = local_threshold_ms
self._server_selection_timeout = server_selection_timeout
self._server_selector = server_selector
self._fqdn = fqdn
self._heartbeat_frequency = heartbeat_frequency
self._direct = direct_connection
self._load_balanced = load_balanced
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack())
@property
def seeds(self) -> Collection[tuple[str, int]]:
"""List of server addresses."""
return self._seeds
@property
def replica_set_name(self) -> Optional[str]:
return self._replica_set_name
@property
def pool_class(self) -> Type[Pool]:
return self._pool_class
@property
def pool_options(self) -> PoolOptions:
return self._pool_options
@property
def monitor_class(self) -> Type[monitor.Monitor]:
return self._monitor_class
@property
def condition_class(self) -> Type[threading.Condition]:
return self._condition_class
@property
def local_threshold_ms(self) -> int:
return self._local_threshold_ms
@property
def server_selection_timeout(self) -> int:
return self._server_selection_timeout
@property
def server_selector(self) -> Optional[_ServerSelector]:
return self._server_selector
@property
def heartbeat_frequency(self) -> int:
return self._heartbeat_frequency
@property
def fqdn(self) -> Optional[str]:
return self._fqdn
@property
def direct(self) -> Optional[bool]:
"""Connect directly to a single server, or use a set of servers?
True if there is one seed and no replica_set_name.
"""
return self._direct
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self._load_balanced
@property
def srv_service_name(self) -> str:
"""The srvServiceName."""
return self._srv_service_name
@property
def srv_max_hosts(self) -> int:
"""The srvMaxHosts."""
return self._srv_max_hosts
@property
def server_monitoring_mode(self) -> str:
"""The serverMonitoringMode."""
return self._server_monitoring_mode
def get_topology_type(self) -> int:
if self.load_balanced:
return TOPOLOGY_TYPE.LoadBalanced
elif self.direct:
return TOPOLOGY_TYPE.Single
elif self.replica_set_name is not None:
return TOPOLOGY_TYPE.ReplicaSetNoPrimary
else:
return TOPOLOGY_TYPE.Unknown
def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]:
"""Initial dict of (address, ServerDescription) for all seeds."""
return {address: ServerDescription(address) for address in self.seeds}

View File

@ -0,0 +1,149 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
from __future__ import annotations
import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo.asynchronous.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from dns import resolver
_IS_SYNC = False
def _have_dnspython() -> bool:
try:
import dns # noqa: F401
return True
except ImportError:
return False
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
if isinstance(text, bytes):
return text.decode()
return text
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
from dns import resolver
if hasattr(resolver, "resolve"):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)
_INVALID_HOST_MSG = (
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
"Did you mean to use 'mongodb://'?"
)
class _SrvResolver:
def __init__(
self,
fqdn: str,
connect_timeout: Optional[float],
srv_service_name: str,
srv_max_hosts: int = 0,
):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
self.__srv_max_hosts = srv_max_hosts or 0
# Validate the fully qualified domain name.
try:
ipaddress.ip_address(fqdn)
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
except ValueError:
pass
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
def get_options(self) -> Optional[str]:
from dns import resolver
try:
results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc)) from None
if len(results) > 1:
raise ConfigurationError("Only one TXT record is supported")
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")
def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
try:
results = _resolve(
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
)
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc)) from None
return results
def _get_srv_response_and_hosts(
self, encapsulate_errors: bool
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
results = self._resolve_uri(encapsulate_errors)
# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results
]
# Validate hosts
for node in nodes:
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
def get_hosts(self) -> list[tuple[str, Any]]:
_, nodes = self._get_srv_response_and_hosts(True)
return nodes
def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
results, nodes = self._get_srv_response_and_hosts(False)
rrset = results.rrset
ttl = rrset.ttl if rrset else 0
return nodes, ttl

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,678 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Represent a deployment of MongoDB servers."""
from __future__ import annotations
from random import sample
from typing import (
Any,
Callable,
List,
Mapping,
MutableMapping,
NamedTuple,
Optional,
cast,
)
from bson.min_key import MinKey
from bson.objectid import ObjectId
from pymongo.asynchronous import common
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode
from pymongo.asynchronous.server_description import ServerDescription
from pymongo.asynchronous.server_selectors import Selection
from pymongo.asynchronous.typings import _Address
from pymongo.errors import ConfigurationError
from pymongo.server_type import SERVER_TYPE
_IS_SYNC = False
# Enumeration for various kinds of MongoDB cluster topologies.
class _TopologyType(NamedTuple):
Single: int
ReplicaSetNoPrimary: int
ReplicaSetWithPrimary: int
Sharded: int
Unknown: int
LoadBalanced: int
TOPOLOGY_TYPE = _TopologyType(*range(6))
# Topologies compatible with SRV record polling.
SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
class TopologyDescription:
def __init__(
self,
topology_type: int,
server_descriptions: dict[_Address, ServerDescription],
replica_set_name: Optional[str],
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
topology_settings: Any,
) -> None:
"""Representation of a deployment of MongoDB servers.
:param topology_type: initial type
:param server_descriptions: dict of (address, ServerDescription) for
all seeds
:param replica_set_name: replica set name or None
:param max_set_version: greatest setVersion seen from a primary, or None
:param max_election_id: greatest electionId seen from a primary, or None
:param topology_settings: a TopologySettings
"""
self._topology_type = topology_type
self._replica_set_name = replica_set_name
self._server_descriptions = server_descriptions
self._max_set_version = max_set_version
self._max_election_id = max_election_id
# The heartbeat_frequency is used in staleness estimates.
self._topology_settings = topology_settings
# Is PyMongo compatible with all servers' wire protocols?
self._incompatible_err = None
if self._topology_type != TOPOLOGY_TYPE.LoadBalanced:
self._init_incompatible_err()
# Server Discovery And Monitoring Spec: Whenever a client updates the
# TopologyDescription from an hello response, it MUST set
# TopologyDescription.logicalSessionTimeoutMinutes to the smallest
# logicalSessionTimeoutMinutes value among ServerDescriptions of all
# data-bearing server types. If any have a null
# logicalSessionTimeoutMinutes, then
# TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null.
readable_servers = self.readable_servers
if not readable_servers:
self._ls_timeout_minutes = None
elif any(s.logical_session_timeout_minutes is None for s in readable_servers):
self._ls_timeout_minutes = None
else:
self._ls_timeout_minutes = min( # type: ignore[type-var]
s.logical_session_timeout_minutes for s in readable_servers
)
def _init_incompatible_err(self) -> None:
"""Internal compatibility check for non-load balanced topologies."""
for s in self._server_descriptions.values():
if not s.is_server_type_known:
continue
# s.min/max_wire_version is the server's wire protocol.
# MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports.
server_too_new = (
# Server too new.
s.min_wire_version is not None
and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION
)
server_too_old = (
# Server too old.
s.max_wire_version is not None
and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION
)
if server_too_new:
self._incompatible_err = (
"Server at %s:%d requires wire version %d, but this " # type: ignore
"version of PyMongo only supports up to %d."
% (
s.address[0],
s.address[1] or 0,
s.min_wire_version,
common.MAX_SUPPORTED_WIRE_VERSION,
)
)
elif server_too_old:
self._incompatible_err = (
"Server at %s:%d reports wire version %d, but this " # type: ignore
"version of PyMongo requires at least %d (MongoDB %s)."
% (
s.address[0],
s.address[1] or 0,
s.max_wire_version,
common.MIN_SUPPORTED_WIRE_VERSION,
common.MIN_SUPPORTED_SERVER_VERSION,
)
)
break
def check_compatible(self) -> None:
"""Raise ConfigurationError if any server is incompatible.
A server is incompatible if its wire protocol version range does not
overlap with PyMongo's.
"""
if self._incompatible_err:
raise ConfigurationError(self._incompatible_err)
def has_server(self, address: _Address) -> bool:
return address in self._server_descriptions
def reset_server(self, address: _Address) -> TopologyDescription:
"""A copy of this description, with one server marked Unknown."""
unknown_sd = self._server_descriptions[address].to_unknown()
return updated_topology_description(self, unknown_sd)
def reset(self) -> TopologyDescription:
"""A copy of this description, with all servers marked Unknown."""
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
else:
topology_type = self._topology_type
# The default ServerDescription's type is Unknown.
sds = {address: ServerDescription(address) for address in self._server_descriptions}
return TopologyDescription(
topology_type,
sds,
self._replica_set_name,
self._max_set_version,
self._max_election_id,
self._topology_settings,
)
def server_descriptions(self) -> dict[_Address, ServerDescription]:
"""dict of (address,
:class:`~pymongo.server_description.ServerDescription`).
"""
return self._server_descriptions.copy()
@property
def topology_type(self) -> int:
"""The type of this topology."""
return self._topology_type
@property
def topology_type_name(self) -> str:
"""The topology type as a human readable string.
.. versionadded:: 3.4
"""
return TOPOLOGY_TYPE._fields[self._topology_type]
@property
def replica_set_name(self) -> Optional[str]:
"""The replica set name."""
return self._replica_set_name
@property
def max_set_version(self) -> Optional[int]:
"""Greatest setVersion seen from a primary, or None."""
return self._max_set_version
@property
def max_election_id(self) -> Optional[ObjectId]:
"""Greatest electionId seen from a primary, or None."""
return self._max_election_id
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
"""Minimum logical session timeout, or None."""
return self._ls_timeout_minutes
@property
def known_servers(self) -> list[ServerDescription]:
"""List of Servers of types besides Unknown."""
return [s for s in self._server_descriptions.values() if s.is_server_type_known]
@property
def has_known_servers(self) -> bool:
"""Whether there are any Servers of types besides Unknown."""
return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
@property
def readable_servers(self) -> list[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]
@property
def common_wire_version(self) -> Optional[int]:
"""Minimum of all servers' max wire versions, or None."""
servers = self.known_servers
if servers:
return min(s.max_wire_version for s in self.known_servers)
return None
@property
def heartbeat_frequency(self) -> int:
return self._topology_settings.heartbeat_frequency
@property
def srv_max_hosts(self) -> int:
return self._topology_settings._srv_max_hosts
def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
if not selection:
return []
round_trip_times: list[float] = []
for server in selection.server_descriptions:
if server.round_trip_time is None:
config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}"
raise ConfigurationError(config_err_msg)
round_trip_times.append(server.round_trip_time)
# Round trip time in seconds.
fastest = min(round_trip_times)
threshold = self._topology_settings.local_threshold_ms / 1000.0
return [
s
for s in selection.server_descriptions
if (cast(float, s.round_trip_time) - fastest) <= threshold
]
def apply_selector(
self,
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None,
) -> list[ServerDescription]:
"""List of servers matching the provided selector(s).
:param selector: a callable that takes a Selection as input and returns
a Selection as output. For example, an instance of a read
preference from :mod:`~pymongo.read_preferences`.
:param address: A server address to select.
:param custom_selector: A callable that augments server
selection rules. Accepts a list of
:class:`~pymongo.server_description.ServerDescription` objects and
return a list of server descriptions that should be considered
suitable for the desired operation.
.. versionadded:: 3.4
"""
if getattr(selector, "min_wire_version", 0):
common_wv = self.common_wire_version
if common_wv and common_wv < selector.min_wire_version:
raise ConfigurationError(
"%s requires min wire version %d, but topology's min"
" wire version is %d" % (selector, selector.min_wire_version, common_wv)
)
if isinstance(selector, _AggWritePref):
selector.selection_hook(self)
if self.topology_type == TOPOLOGY_TYPE.Unknown:
return []
elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced):
# Ignore selectors for standalone and load balancer mode.
return self.known_servers
if address:
# Ignore selectors when explicit address is requested.
description = self.server_descriptions().get(address)
return [description] if description else []
selection = Selection.from_topology_description(self)
# Ignore read preference for sharded clusters.
if self.topology_type != TOPOLOGY_TYPE.Sharded:
selection = selector(selection)
# Apply custom selector followed by localThresholdMS.
if custom_selector is not None and selection:
selection = selection.with_server_descriptions(
custom_selector(selection.server_descriptions)
)
return self._apply_local_threshold(selection)
def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool:
"""Does this topology have any readable servers available matching the
given read preference?
:param read_preference: an instance of a read preference from
:mod:`~pymongo.read_preferences`. Defaults to
:attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`.
.. note:: When connected directly to a single server this method
always returns ``True``.
.. versionadded:: 3.4
"""
common.validate_read_preference("read_preference", read_preference)
return any(self.apply_selector(read_preference))
def has_writable_server(self) -> bool:
"""Does this topology have a writable server available?
.. note:: When connected directly to a single server this method
always returns ``True``.
.. versionadded:: 3.4
"""
return self.has_readable_server(ReadPreference.PRIMARY)
def __repr__(self) -> str:
# Sort the servers by address.
servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address)
return "<{} id: {}, topology_type: {}, servers: {!r}>".format(
self.__class__.__name__,
self._topology_settings._topology_id,
self.topology_type_name,
servers,
)
# If topology type is Unknown and we receive a hello response, what should
# the new topology type be?
_SERVER_TYPE_TO_TOPOLOGY_TYPE = {
SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded,
SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary,
SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
# Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out.
}
def updated_topology_description(
topology_description: TopologyDescription, server_description: ServerDescription
) -> TopologyDescription:
"""Return an updated copy of a TopologyDescription.
:param topology_description: the current TopologyDescription
:param server_description: a new ServerDescription that resulted from
a hello call
Called after attempting (successfully or not) to call hello on the
server at server_description.address. Does not modify topology_description.
"""
address = server_description.address
# These values will be updated, if necessary, to form the new
# TopologyDescription.
topology_type = topology_description.topology_type
set_name = topology_description.replica_set_name
max_set_version = topology_description.max_set_version
max_election_id = topology_description.max_election_id
server_type = server_description.server_type
# Don't mutate the original dict of server descriptions; copy it.
sds = topology_description.server_descriptions()
# Replace this server's description with the new one.
sds[address] = server_description
if topology_type == TOPOLOGY_TYPE.Single:
# Set server type to Unknown if replica set name does not match.
if set_name is not None and set_name != server_description.replica_set_name:
error = ConfigurationError(
"client is configured to connect to a replica set named "
"'{}' but this node belongs to a set named '{}'".format(
set_name, server_description.replica_set_name
)
)
sds[address] = server_description.to_unknown(error=error)
# Single type never changes.
return TopologyDescription(
TOPOLOGY_TYPE.Single,
sds,
set_name,
max_set_version,
max_election_id,
topology_description._topology_settings,
)
if topology_type == TOPOLOGY_TYPE.Unknown:
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer):
if len(topology_description._topology_settings.seeds) == 1:
topology_type = TOPOLOGY_TYPE.Single
else:
# Remove standalone from Topology when given multiple seeds.
sds.pop(address)
elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost):
topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type]
if topology_type == TOPOLOGY_TYPE.Sharded:
if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown):
sds.pop(address)
elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary:
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
sds.pop(address)
elif server_type == SERVER_TYPE.RSPrimary:
(topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
sds, set_name, server_description, max_set_version, max_election_id
)
elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
topology_type, set_name = _update_rs_no_primary_from_member(
sds, set_name, server_description
)
elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
sds.pop(address)
topology_type = _check_has_primary(sds)
elif server_type == SERVER_TYPE.RSPrimary:
(topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
sds, set_name, server_description, max_set_version, max_election_id
)
elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description)
else:
# Server type is Unknown or RSGhost: did we just lose the primary?
topology_type = _check_has_primary(sds)
# Return updated copy.
return TopologyDescription(
topology_type,
sds,
set_name,
max_set_version,
max_election_id,
topology_description._topology_settings,
)
def _updated_topology_description_srv_polling(
topology_description: TopologyDescription, seedlist: list[tuple[str, Any]]
) -> TopologyDescription:
"""Return an updated copy of a TopologyDescription.
:param topology_description: the current TopologyDescription
:param seedlist: a list of new seeds new ServerDescription that resulted from
a hello call
"""
assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
# Create a copy of the server descriptions.
sds = topology_description.server_descriptions()
# If seeds haven't changed, don't do anything.
if set(sds.keys()) == set(seedlist):
return topology_description
# Remove SDs corresponding to servers no longer part of the SRV record.
for address in list(sds.keys()):
if address not in seedlist:
sds.pop(address)
if topology_description.srv_max_hosts != 0:
new_hosts = set(seedlist) - set(sds.keys())
n_to_add = topology_description.srv_max_hosts - len(sds)
if n_to_add > 0:
seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
else:
seedlist = []
# Add SDs corresponding to servers recently added to the SRV record.
for address in seedlist:
if address not in sds:
sds[address] = ServerDescription(address)
return TopologyDescription(
topology_description.topology_type,
sds,
topology_description.replica_set_name,
topology_description.max_set_version,
topology_description.max_election_id,
topology_description._topology_settings,
)
def _update_rs_from_primary(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
"""Update topology description from a primary's hello response.
Pass in a dict of ServerDescriptions, current replica set name, the
ServerDescription we are processing, and the TopologyDescription's
max_set_version and max_election_id if any.
Returns (new topology type, new replica_set_name, new max_set_version,
new max_election_id).
"""
if replica_set_name is None:
replica_set_name = server_description.replica_set_name
elif replica_set_name != server_description.replica_set_name:
# We found a primary but it doesn't have the replica_set_name
# provided by the user.
sds.pop(server_description.address)
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
if server_description.max_wire_version is None or server_description.max_wire_version < 17:
new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
max_election_tuple: tuple = (max_set_version, max_election_id)
if None not in new_election_tuple:
if None not in max_election_tuple and new_election_tuple < max_election_tuple:
# Stale primary, set to type Unknown.
sds[server_description.address] = server_description.to_unknown()
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
max_election_id = server_description.election_id
if server_description.set_version is not None and (
max_set_version is None or server_description.set_version > max_set_version
):
max_set_version = server_description.set_version
else:
new_election_tuple = server_description.election_id, server_description.set_version
max_election_tuple = max_election_id, max_set_version
new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple)
max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple)
if new_election_safe < max_election_safe:
# Stale primary, set to type Unknown.
sds[server_description.address] = server_description.to_unknown()
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
else:
max_election_id = server_description.election_id
max_set_version = server_description.set_version
# We've heard from the primary. Is it the same primary as before?
for server in sds.values():
if (
server.server_type is SERVER_TYPE.RSPrimary
and server.address != server_description.address
):
# Reset old primary's type to Unknown.
sds[server.address] = server.to_unknown()
# There can be only one prior primary.
break
# Discover new hosts from this primary's response.
for new_address in server_description.all_hosts:
if new_address not in sds:
sds[new_address] = ServerDescription(new_address)
# Remove hosts not in the response.
for addr in set(sds) - server_description.all_hosts:
sds.pop(addr)
# If the host list differs from the seed list, we may not have a primary
# after all.
return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
def _update_rs_with_primary_from_member(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
) -> int:
"""RS with known primary. Process a response from a non-primary.
Pass in a dict of ServerDescriptions, current replica set name, and the
ServerDescription we are processing.
Returns new topology type.
"""
assert replica_set_name is not None
if replica_set_name != server_description.replica_set_name:
sds.pop(server_description.address)
elif server_description.me and server_description.address != server_description.me:
sds.pop(server_description.address)
# Had this member been the primary?
return _check_has_primary(sds)
def _update_rs_no_primary_from_member(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
) -> tuple[int, Optional[str]]:
"""RS without known primary. Update from a non-primary's response.
Pass in a dict of ServerDescriptions, current replica set name, and the
ServerDescription we are processing.
Returns (new topology type, new replica_set_name).
"""
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
if replica_set_name is None:
replica_set_name = server_description.replica_set_name
elif replica_set_name != server_description.replica_set_name:
sds.pop(server_description.address)
return topology_type, replica_set_name
# This isn't the primary's response, so don't remove any servers
# it doesn't report. Only add new servers.
for address in server_description.all_hosts:
if address not in sds:
sds[address] = ServerDescription(address)
if server_description.me and server_description.address != server_description.me:
sds.pop(server_description.address)
return topology_type, replica_set_name
def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int:
"""Current topology type is ReplicaSetWithPrimary. Is primary still known?
Pass in a dict of ServerDescriptions.
Returns new topology type.
"""
for s in sds.values():
if s.server_type == SERVER_TYPE.RSPrimary:
return TOPOLOGY_TYPE.ReplicaSetWithPrimary
else: # noqa: PLW0120
return TOPOLOGY_TYPE.ReplicaSetNoPrimary

View File

@ -0,0 +1,61 @@
# Copyright 2022-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Type aliases used by PyMongo"""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg
if TYPE_CHECKING:
from pymongo.asynchronous.collation import Collation
_IS_SYNC = False
# Common Shared Types.
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_Pipeline = Sequence[Mapping[str, Any]]
ClusterTime = Mapping[str, Any]
_T = TypeVar("_T")
def strip_optional(elem: Optional[_T]) -> _T:
"""This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T
while inside a list comprehension.
"""
assert elem is not None
return elem
__all__ = [
"_DocumentOut",
"_DocumentType",
"_DocumentTypeArg",
"_Address",
"_CollationIn",
"_Pipeline",
"strip_optional",
]

View File

@ -0,0 +1,624 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse and validate a MongoDB URI."""
from __future__ import annotations
import re
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Mapping,
MutableMapping,
Optional,
Sized,
Union,
cast,
)
from urllib.parse import unquote_plus
from pymongo.asynchronous.client_options import _parse_ssl_options
from pymongo.asynchronous.common import (
INTERNAL_URI_OPTION_NAME_MAP,
SRV_SERVICE_NAME,
URI_OPTIONS_DEPRECATION_MAP,
_CaseInsensitiveDictionary,
get_validated_options,
)
from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver
from pymongo.asynchronous.typings import _Address
from pymongo.errors import ConfigurationError, InvalidURI
if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
_IS_SYNC = False
SCHEME = "mongodb://"
SCHEME_LEN = len(SCHEME)
SRV_SCHEME = "mongodb+srv://"
SRV_SCHEME_LEN = len(SRV_SCHEME)
DEFAULT_PORT = 27017
def _unquoted_percent(s: str) -> bool:
"""Check for unescaped percent signs.
:param s: A string. `s` can have things like '%25', '%2525',
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
"""
for i in range(len(s)):
if s[i] == "%":
sub = s[i : i + 3]
# If unquoting yields the same string this means there was an
# unquoted %.
if unquote_plus(sub) == sub:
return True
return False
def parse_userinfo(userinfo: str) -> tuple[str, str]:
"""Validates the format of user information in a MongoDB URI.
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
"]", "@") as per RFC 3986 must be escaped.
Returns a 2-tuple containing the unescaped username followed
by the unescaped password.
:param userinfo: A string of the form <username>:<password>
"""
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
raise InvalidURI(
"Username and password must be escaped according to "
"RFC 3986, use urllib.parse.quote_plus"
)
user, _, passwd = userinfo.partition(":")
# No password is expected with GSSAPI authentication.
if not user:
raise InvalidURI("The empty string is not valid username.")
return unquote_plus(user), unquote_plus(passwd)
def parse_ipv6_literal_host(
entity: str, default_port: Optional[int]
) -> tuple[str, Optional[Union[str, int]]]:
"""Validates an IPv6 literal host:port string.
Returns a 2-tuple of IPv6 literal followed by port where
port is default_port if it wasn't specified in entity.
:param entity: A string that represents an IPv6 literal enclosed
in braces (e.g. '[::1]' or '[::1]:27017').
:param default_port: The port number to use when one wasn't
specified in entity.
"""
if entity.find("]") == -1:
raise ValueError(
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
)
i = entity.find("]:")
if i == -1:
return entity[1:-1], default_port
return entity[1:i], entity[i + 2 :]
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
"""Validates a host string
Returns a 2-tuple of host followed by port where port is default_port
if it wasn't specified in the string.
:param entity: A host or host:port string where host could be a
hostname or IP address.
:param default_port: The port number to use when one wasn't
specified in entity.
"""
host = entity
port: Optional[Union[str, int]] = default_port
if entity[0] == "[":
host, port = parse_ipv6_literal_host(entity, default_port)
elif entity.endswith(".sock"):
return entity, default_port
elif entity.find(":") != -1:
if entity.count(":") > 1:
raise ValueError(
"Reserved characters such as ':' must be "
"escaped according RFC 2396. An IPv6 "
"address literal must be enclosed in '[' "
"and ']' according to RFC 2732."
)
host, port = host.split(":", 1)
if isinstance(port, str):
if not port.isdigit() or int(port) > 65535 or int(port) <= 0:
raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}")
port = int(port)
# Normalize hostname to lowercase, since DNS is case-insensitive:
# http://tools.ietf.org/html/rfc4343
# This prevents useless rediscovery if "foo.com" is in the seed list but
# "FOO.com" is in the hello response.
return host.lower(), port
# Options whose values are implicitly determined by tlsInsecure.
_IMPLICIT_TLSINSECURE_OPTS = {
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
}
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
"""Helper method for split_options which creates the options dict.
Also handles the creation of a list for the URI tag_sets/
readpreferencetags portion, and the use of a unicode options string.
"""
options = _CaseInsensitiveDictionary()
for uriopt in opts.split(delim):
key, value = uriopt.split("=")
if key.lower() == "readpreferencetags":
options.setdefault(key, []).append(value)
else:
if key in options:
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
if key.lower() == "authmechanismproperties":
val = value
else:
val = unquote_plus(value)
options[key] = val
return options
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Raise appropriate errors when conflicting TLS options are present in
the options dictionary.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Implicitly defined options must not be explicitly specified.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
if opt in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
)
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
if tlsallowinvalidcerts is not None:
if "tlsdisableocspendpointcheck" in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
)
if tlsallowinvalidcerts is True:
options["tlsdisableocspendpointcheck"] = True
# Handle co-occurence of CRL and OCSP-related options.
tlscrlfile = options.get("tlscrlfile")
if tlscrlfile is not None:
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
if options.get(opt) is True:
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
raise InvalidURI(err_msg % (opt,))
if "ssl" in options and "tls" in options:
def truth_value(val: Any) -> Any:
if val in ("true", "false"):
return val == "true"
if isinstance(val, bool):
return val
return val
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
err_msg = "Can not specify conflicting values for URI options %s and %s."
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
return options
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Issue appropriate warnings when deprecated options are present in the
options dictionary. Removes deprecated option key, value pairs if the
options dictionary is found to also have the renamed option.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
for optname in list(options):
if optname in URI_OPTIONS_DEPRECATION_MAP:
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
if mode == "renamed":
newoptname = message
if newoptname in options:
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
warnings.warn(
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
DeprecationWarning,
stacklevel=2,
)
options.pop(optname)
continue
warn_msg = "Option '%s' is deprecated, use '%s' instead."
warnings.warn(
warn_msg % (options.cased_key(optname), newoptname),
DeprecationWarning,
stacklevel=2,
)
elif mode == "removed":
warn_msg = "Option '%s' is deprecated. %s."
warnings.warn(
warn_msg % (options.cased_key(optname), message),
DeprecationWarning,
stacklevel=2,
)
return options
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Normalizes option names in the options dictionary by converting them to
their internally-used names.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Expand the tlsInsecure option.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
# Implicit options are logically the same as tlsInsecure.
options[opt] = tlsinsecure
for optname in list(options):
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
if intname is not None:
options[intname] = options.pop(optname)
return options
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
"""Validates and normalizes options passed in a MongoDB URI.
Returns a new dictionary of validated and normalized options. If warn is
False then errors will be thrown for invalid options, otherwise they will
be ignored and a warning will be issued.
:param opts: A dict of MongoDB URI options.
:param warn: If ``True`` then warnings will be logged and
invalid options will be ignored. Otherwise invalid options will
cause errors.
"""
return get_validated_options(opts, warn)
def split_options(
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
) -> MutableMapping[str, Any]:
"""Takes the options portion of a MongoDB URI, validates each option
and returns the options in a dictionary.
:param opt: A string representing MongoDB URI options.
:param validate: If ``True`` (the default), validate and normalize all
options.
:param warn: If ``False`` (the default), suppress all warnings raised
during validation of options.
:param normalize: If ``True`` (the default), renames all options to their
internally-used names.
"""
and_idx = opts.find("&")
semi_idx = opts.find(";")
try:
if and_idx >= 0 and semi_idx >= 0:
raise InvalidURI("Can not mix '&' and ';' for option separators.")
elif and_idx >= 0:
options = _parse_options(opts, "&")
elif semi_idx >= 0:
options = _parse_options(opts, ";")
elif opts.find("=") != -1:
options = _parse_options(opts, None)
else:
raise ValueError
except ValueError:
raise InvalidURI("MongoDB URI options are key=value pairs.") from None
options = _handle_security_options(options)
options = _handle_option_deprecations(options)
if normalize:
options = _normalize_options(options)
if validate:
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
if options.get("authsource") == "":
raise InvalidURI("the authSource database cannot be an empty string")
return options
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
"""Takes a string of the form host1[:port],host2[:port]... and
splits it into (host, port) tuples. If [:port] isn't present the
default_port is used.
Returns a set of 2-tuples containing the host name (or IP) followed by
port number.
:param hosts: A string of the form host1[:port],host2[:port],...
:param default_port: The port number to use when one wasn't specified
for a host.
"""
nodes = []
for entity in hosts.split(","):
if not entity:
raise ConfigurationError("Empty host (or extra comma in host list).")
port = default_port
# Unix socket entities don't have ports
if entity.endswith(".sock"):
port = None
nodes.append(parse_host(entity, port))
return nodes
# Prohibited characters in database name. DB names also can't have ".", but for
# backward-compat we allow "db.collection" in URI.
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
_ALLOWED_TXT_OPTS = frozenset(
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
)
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
# Ensure directConnection was not True if there are multiple seeds.
if len(nodes) > 1 and options.get("directconnection"):
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
if options.get("loadbalanced"):
if len(nodes) > 1:
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
if options.get("directconnection"):
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
if options.get("replicaset"):
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
def parse_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
{
'nodelist': <list of (host, port) tuples>,
'username': <username> or None,
'password': <password> or None,
'database': <database name> or None,
'collection': <collection name> or None,
'options': <dict of MongoDB URI options>,
'fqdn': <fqdn of the MongoDB+SRV URI> or None
}
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
to build nodelist and options.
:param uri: The MongoDB URI to parse.
:param default_port: The port number to use when one wasn't specified
for a host in the URI.
:param validate: If ``True`` (the default), validate and
normalize all options. Default: ``True``.
:param warn: When validating, if ``True`` then will warn
the user then ignore any invalid options or values. If ``False``,
validation will error when options are unsupported or values are
invalid. Default: ``False``.
:param normalize: If ``True``, convert names of URI options
to their internally-used names. Default: ``True``.
:param connect_timeout: The maximum time in milliseconds to
wait for a response from the DNS server.
:param srv_service_name: A custom SRV service name
.. versionchanged:: 4.6
The delimiting slash (``/``) between hosts and connection options is now optional.
For example, "mongodb://example.com?tls=true" is now a valid URI.
.. versionchanged:: 4.0
To better follow RFC 3986, unquoted percent signs ("%") are no longer
supported.
.. versionchanged:: 3.9
Added the ``normalize`` parameter.
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
.. versionchanged:: 3.5
Return the original value of the ``readPreference`` MongoDB URI option
instead of the validated read preference mode.
.. versionchanged:: 3.1
``warn`` added so invalid options can be ignored.
"""
if uri.startswith(SCHEME):
is_srv = False
scheme_free = uri[SCHEME_LEN:]
elif uri.startswith(SRV_SCHEME):
if not _have_dnspython():
python_path = sys.executable or "python"
raise ConfigurationError(
'The "dnspython" module must be '
"installed to use mongodb+srv:// URIs. "
"To fix this error install pymongo again:\n "
"%s -m pip install pymongo>=4.3" % (python_path)
)
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
else:
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
if not scheme_free:
raise InvalidURI("Must provide at least one hostname or IP.")
user = None
passwd = None
dbase = None
collection = None
options = _CaseInsensitiveDictionary()
host_plus_db_part, _, opts = scheme_free.partition("?")
if "/" in host_plus_db_part:
host_part, _, dbase = host_plus_db_part.partition("/")
else:
host_part = host_plus_db_part
if dbase:
dbase = unquote_plus(dbase)
if "." in dbase:
dbase, collection = dbase.split(".", 1)
if _BAD_DB_CHARS.search(dbase):
raise InvalidURI('Bad database name "%s"' % dbase)
else:
dbase = None
if opts:
options.update(split_options(opts, validate, warn, normalize))
if srv_service_name is None:
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
if "@" in host_part:
userinfo, _, hosts = host_part.rpartition("@")
user, passwd = parse_userinfo(userinfo)
else:
hosts = host_part
if "/" in hosts:
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
hosts = unquote_plus(hosts)
fqdn = None
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
if options.get("directConnection"):
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
nodes = split_hosts(hosts, default_port=None)
if len(nodes) != 1:
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
fqdn, port = nodes[0]
if port is not None:
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
nodes = dns_resolver.get_hosts()
dns_options = dns_resolver.get_options()
if dns_options:
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
raise ConfigurationError(
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
)
for opt, val in parsed_dns_options.items():
if opt not in options:
options[opt] = val
if options.get("loadBalanced") and srv_max_hosts:
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
if options.get("replicaSet") and srv_max_hosts:
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
if "tls" not in options and "ssl" not in options:
options["tls"] = True if validate else "true"
elif not is_srv and options.get("srvServiceName") is not None:
raise ConfigurationError(
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
)
elif not is_srv and srv_max_hosts:
raise ConfigurationError(
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
)
else:
nodes = split_hosts(hosts, default_port=default_port)
_check_options(nodes, options)
return {
"nodelist": nodes,
"username": user,
"password": passwd,
"database": dbase,
"collection": collection,
"options": options,
"fqdn": fqdn,
}
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
"""Parse KMS TLS connection options."""
if not kms_tls_options:
return {}
if not isinstance(kms_tls_options, dict):
raise TypeError("kms_tls_options must be a dict")
contexts = {}
for provider, options in kms_tls_options.items():
if not isinstance(options, dict):
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
options.setdefault("tls", True)
opts = _CaseInsensitiveDictionary(options)
opts = _handle_security_options(opts)
opts = _normalize_options(opts)
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
if ssl_context is None:
raise ConfigurationError("TLS is required for KMS providers")
if allow_invalid_hostnames:
raise ConfigurationError("Insecure TLS options prohibited")
for n in [
"tlsInsecure",
"tlsAllowInvalidCertificates",
"tlsAllowInvalidHostnames",
"tlsDisableCertificateRevocationCheck",
]:
if n in opts:
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
contexts[provider] = ssl_context
return contexts
if __name__ == "__main__":
import pprint
try:
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
except InvalidURI as exc:
print(exc) # noqa: T201
sys.exit(0)

View File

@ -1,4 +1,4 @@
# Copyright 2013-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,645 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Authentication helpers."""
"""Re-import of synchronous Auth API for compatibility."""
from __future__ import annotations
import functools
import hashlib
import hmac
import os
import socket
import typing
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
MutableMapping,
Optional,
cast,
)
from urllib.parse import quote
from pymongo.synchronous.auth import * # noqa: F403
from pymongo.synchronous.auth import __doc__ as original_doc
from bson.binary import Binary
from pymongo.auth_aws import _authenticate_aws
from pymongo.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
if TYPE_CHECKING:
from pymongo.hello import Hello
from pymongo.pool import Connection
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
import winkerberos as kerberos # type:ignore[import]
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
_USE_PRINCIPAL = True
except ImportError:
try:
import kerberos # type:ignore[import]
except ImportError:
HAVE_KERBEROS = False
MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DEFAULT",
]
)
"""The authentication mechanisms supported by PyMongo."""
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
def __init__(self) -> None:
self.data = None
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self) -> int:
return self._hash_val
MongoCredential = namedtuple(
"MongoCredential",
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
)
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple(
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
)
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
properties = extra.get("authmechanismproperties", {})
service_name = properties.get("SERVICE_NAME", "mongodb")
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
service_realm = properties.get("SERVICE_REALM")
props = GSSAPIProperties(
service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm,
)
# Source is always $external.
return MongoCredential(mech, "$external", user, passwd, props, None)
elif mech == "MONGODB-X509":
if passwd is not None:
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for MONGODB-X509")
# Source is always $external, user can be None.
return MongoCredential(mech, "$external", user, None, None, None)
elif mech == "MONGODB-AWS":
if user is not None and passwd is None:
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
if source is not None and source != "$external":
raise ConfigurationError(
"authentication source must be $external or None for MONGODB-AWS"
)
properties = extra.get("authmechanismproperties", {})
aws_session_token = properties.get("AWS_SESSION_TOKEN")
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
callback = properties.get("OIDC_CALLBACK")
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
environ = properties.get("ENVIRONMENT")
token_resource = properties.get("TOKEN_RESOURCE", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
if passwd is not None:
msg = "password is not supported by MONGODB-OIDC"
raise ConfigurationError(msg)
if callback or human_callback:
if environ is not None:
raise ConfigurationError(msg)
if callback and human_callback:
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
raise ConfigurationError(msg)
elif environ is not None:
if environ == "test":
if user is not None:
msg = "test environment for MONGODB-OIDC does not support username"
raise ConfigurationError(msg)
callback = _OIDCTestCallback()
elif environ == "azure":
passwd = None
if not token_resource:
raise ConfigurationError(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:
raise ConfigurationError(msg)
oidc_props = _OIDCProperties(
callback=callback,
human_callback=human_callback,
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
"""Split a scram response into key, value pairs."""
return dict(
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
for item in response.split(b",")
)
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = {
"saslStart": 1,
"mechanism": mechanism,
"payload": Binary(b"n,," + first_bare),
"autoAuthorize": 1,
"options": {"skipEmptyExchange": True},
}
return nonce, first_bare, cmd
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == "SCRAM-SHA-256":
digest = "sha256"
digestmod = hashlib.sha256
data = saslprep(credentials.password).encode("utf-8")
else:
digest = "sha1"
digestmod = hashlib.sha1
data = _password_digest(username, credentials.password).encode("utf-8")
source = credentials.source
cache = credentials.cache
# Make local
_hmac = hmac.HMAC
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
assert isinstance(ctx, _ScramContext)
assert ctx.scram_data is not None
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = conn.command(source, cmd)
assert res is not None
server_first = res["payload"]
parsed = _parse_scram_response(server_first)
iterations = int(parsed[b"i"])
if iterations < 4096:
raise OperationFailure("Server returned an invalid iteration count.")
salt = parsed[b"s"]
rnonce = parsed[b"r"]
if not rnonce.startswith(nonce):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
if cache.data:
client_key, server_key, csalt, citerations = cache.data
else:
client_key, server_key, csalt, citerations = None, None, None, None
# Salt and / or iterations could change for a number of different
# reasons. Either changing invalidates the cache.
if not client_key or salt != csalt or iterations != citerations:
salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key, salt, iterations)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
client_final = b",".join((without_proof, client_proof))
server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(client_final),
}
res = conn.command(source, cmd)
parsed = _parse_scram_response(res["payload"])
if not hmac.compare_digest(parsed[b"v"], server_sig):
raise OperationFailure("Server returned an invalid signature.")
# A third empty challenge may be required if the server does not support
# skipEmptyExchange: SERVER-44857.
if not res["done"]:
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(b""),
}
res = conn.command(source, cmd)
if not res["done"]:
raise OperationFailure("SASL conversation failed to complete.")
def _password_digest(username: str, password: str) -> str:
"""Get a password digest to use for authentication."""
if not isinstance(password, str):
raise TypeError("password must be an instance of str")
if len(password) == 0:
raise ValueError("password can't be empty")
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _auth_key(nonce: str, username: str, password: str) -> str:
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5() # noqa: S324
data = f"{nonce}{username}{digest}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _canonicalize_hostname(hostname: str) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
try:
name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
except socket.gaierror:
return canonname.lower()
return name[0].lower()
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
'The "kerberos" module must be installed to use GSSAPI authentication.'
)
try:
username = credentials.username
password = credentials.password
props = credentials.mechanism_properties
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = conn.address[0]
if props.canonicalize_host_name:
host = _canonicalize_hostname(host)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
if password is not None:
if _USE_PRINCIPAL:
# Note that, though we use unquote_plus for unquoting URI
# options, we use quote here. Microsoft's UrlUnescape (used
# by WinKerberos) doesn't support +.
principal = ":".join((quote(username), quote(password)))
result, ctx = kerberos.authGSSClientInit(
service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
)
else:
if "@" in username:
user, domain = username.split("@", 1)
else:
user, domain = username, None
result, ctx = kerberos.authGSSClientInit(
service,
gssflags=kerberos.GSS_C_MUTUAL_FLAG,
user=user,
domain=domain,
password=password,
)
else:
result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
if result != kerberos.AUTH_GSS_COMPLETE:
raise OperationFailure("Kerberos context failed to initialize.")
try:
# pykerberos uses a weird mix of exceptions and return values
# to indicate errors.
# 0 == continue, 1 == complete, -1 == error
# Only authGSSClientStep can return 0.
if kerberos.authGSSClientStep(ctx, "") != 0:
raise OperationFailure("Unknown kerberos failure in step function.")
# Start a SASL conversation with mongod/s
# Note: pykerberos deals with base64 encoded byte strings.
# Since mongo accepts base64 strings as the payload we don't
# have to use bson.binary.Binary.
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslStart": 1,
"mechanism": "GSSAPI",
"payload": payload,
"autoAuthorize": 1,
}
response = conn.command("$external", cmd)
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
if result == -1:
raise OperationFailure("Unknown kerberos failure in step function.")
payload = kerberos.authGSSClientResponse(ctx) or ""
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
response = conn.command("$external", cmd)
if result == kerberos.AUTH_GSS_COMPLETE:
break
else:
raise OperationFailure("Kerberos authentication failed to complete.")
# Once the security context is established actually authenticate.
# See RFC 4752, Section 3.1, last two paragraphs.
if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")
if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
conn.command("$external", cmd)
finally:
kerberos.authGSSClientClean(ctx)
except kerberos.KrbError as exc:
raise OperationFailure(str(exc)) from None
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
password = credentials.password
payload = (f"\x00{username}\x00{password}").encode()
cmd = {
"saslStart": 1,
"mechanism": "PLAIN",
"payload": Binary(payload),
"autoAuthorize": 1,
}
conn.command(source, cmd)
def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-X509."""
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step.
return
cmd = _X509Context(credentials, conn.address).speculate_command()
conn.command("$external", cmd)
def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
password = credentials.password
# Get a nonce
response = conn.command(source, {"getnonce": 1})
nonce = response["nonce"]
key = _auth_key(nonce, username, password)
# Actually authenticate
query = {"authenticate": 1, "user": username, "nonce": nonce, "key": key}
conn.command(source, query)
def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
if conn.max_wire_version >= 7:
if conn.negotiated_mechs:
mechs = conn.negotiated_mechs
else:
source = credentials.source
cmd = conn.hello_cmd()
cmd["saslSupportedMechs"] = source + "." + credentials.username
mechs = conn.command(source, cmd, publish_events=False).get("saslSupportedMechs", [])
if "SCRAM-SHA-256" in mechs:
return _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
else:
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
else:
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
_AUTH_MAP: Mapping[str, Callable[..., None]] = {
"GSSAPI": _authenticate_gssapi,
"MONGODB-CR": _authenticate_mongo_cr,
"MONGODB-X509": _authenticate_x509,
"MONGODB-AWS": _authenticate_aws,
"MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item]
"PLAIN": _authenticate_plain,
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
"DEFAULT": _authenticate_default,
}
class _AuthContext:
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
self.credentials = credentials
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
self.address = address
@staticmethod
def from_credentials(
creds: MongoCredential, address: tuple[str, int]
) -> Optional[_AuthContext]:
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
if spec_cls:
return cast(_AuthContext, spec_cls(creds, address))
return None
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
raise NotImplementedError
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
self.speculative_authenticate = hello.speculative_authenticate
def speculate_succeeded(self) -> bool:
return bool(self.speculative_authenticate)
class _ScramContext(_AuthContext):
def __init__(
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
) -> None:
super().__init__(credentials, address)
self.scram_data: Optional[tuple[bytes, bytes]] = None
self.mechanism = mechanism
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
# The 'db' field is included only on the speculative command.
cmd["db"] = self.credentials.source
# Save for later use.
self.scram_data = (nonce, first_bare)
return cmd
class _X509Context(_AuthContext):
def speculate_command(self) -> MutableMapping[str, Any]:
cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
if self.credentials.username is not None:
cmd["user"] = self.credentials.username
return cmd
class _OIDCContext(_AuthContext):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.get_spec_auth_cmd()
if cmd is None:
return None
cmd["db"] = self.credentials.source
return cmd
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
"MONGODB-X509": _X509Context,
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
"MONGODB-OIDC": _OIDCContext,
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}
def authenticate(
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
"""Authenticate connection."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]
if mechanism == "MONGODB-OIDC":
_authenticate_oidc(credentials, conn, reauthenticate)
else:
auth_func(credentials, conn)
__doc__ = original_doc

View File

@ -1,4 +1,4 @@
# Copyright 2023-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,354 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-OIDC Authentication helpers."""
"""Re-import of synchronous AuthOIDC API for compatibility."""
from __future__ import annotations
import abc
import os
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
from urllib.parse import quote
from pymongo.synchronous.auth_oidc import * # noqa: F403
from pymongo.synchronous.auth_oidc import __doc__ as original_doc
import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.auth import MongoCredential
from pymongo.pool import Connection
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@dataclass
class OIDCCallbackContext:
timeout_seconds: float
username: str
version: int
refresh_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
@dataclass
class OIDCCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)
class OIDCCallback(abc.ABC):
"""A base class for defining OIDC callbacks."""
@abc.abstractmethod
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
"""Convert the given BSON value into our own type."""
@dataclass
class _OIDCProperties:
callback: Optional[OIDCCallback] = field(default=None)
human_callback: Optional[OIDCCallback] = field(default=None)
environment: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)
token_resource: Optional[str] = field(default=None)
username: str = ""
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
def _get_authenticator(
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
if credentials.cache.data:
return credentials.cache.data
# Extract values.
principal_name = credentials.username
properties = credentials.mechanism_properties
# Validate that the address is allowed.
if not properties.environment:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:
if patt == address[0]:
found = True
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
found = True
if not found:
raise ConfigurationError(
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
)
# Get or create the cache data.
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
return credentials.cache.data
class _OIDCTestCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("OIDC_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAzureCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
return OIDCCallbackResult(
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])
@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
refresh_token: Optional[str] = field(default=None)
access_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)
last_call_time: float = field(default=0)
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle a reauthenticate from the server."""
# Invalidate the token for the connection.
self._invalidate(conn)
# Call the appropriate auth logic for the callback type.
if self.properties.callback:
return self._authenticate_machine(conn)
return self._authenticate_human(conn)
def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle an initial authenticate request."""
# First handle speculative auth.
# If it succeeded, we are done.
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
resp = ctx.speculative_authenticate
if resp and resp["done"]:
conn.oidc_token_gen_id = self.token_gen_id
return resp
# If spec auth failed, call the appropriate auth logic for the callback type.
# We cannot assume that the token is invalid, because a proxy may have been
# involved that stripped the speculative auth information.
if self.properties.callback:
return self._authenticate_machine(conn)
return self._authenticate_human(conn)
def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
"""Get the appropriate speculative auth command."""
if not self.access_token:
return None
return self._get_start_command({"jwt": self.access_token})
def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
# fails for any other reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_machine(conn)
raise
return self._sasl_start_jwt(conn)
def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
# reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_human(conn)
raise
# If we have a cached refresh token, try a JwtStepRequest with that.
# If authentication fails with error code 18, invalidate the access and
# refresh tokens, and try to authenticate again. If authentication fails for
# any other reason, raise the error to the user.
if self.refresh_token:
try:
return self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
self.refresh_token = None
return self._authenticate_human(conn)
raise
# Start a new Two-Step SASL conversation.
# Run a PrincipalStepRequest to get the IdpInfo.
cmd = self._get_start_command(None)
start_resp = self._run_command(conn, cmd)
# Attempt to authenticate with a JwtStepRequest.
return self._sasl_continue_jwt(conn, start_resp)
def _get_access_token(self) -> Optional[str]:
properties = self.properties
cb: Union[None, OIDCCallback]
resp: OIDCCallbackResult
is_human = properties.human_callback is not None
if is_human and self.idp_info is None:
return None
if properties.callback:
cb = properties.callback
if properties.human_callback:
cb = properties.human_callback
prev_token = self.access_token
if prev_token:
return prev_token
if cb is None and not prev_token:
return None
if not prev_token and cb is not None:
with self.lock:
# See if the token was changed while we were waiting for the
# lock.
new_token = self.access_token
if new_token != prev_token:
return new_token
# Ensure that we are waiting a min time between callback invocations.
delta = time.time() - self.last_call_time
if delta < TIME_BETWEEN_CALLS_SECONDS:
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
self.last_call_time = time.time()
if is_human:
timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
assert self.idp_info is not None
else:
timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
context = OIDCCallbackContext(
timeout_seconds=timeout,
version=CALLBACK_VERSION,
refresh_token=self.refresh_token,
idp_info=self.idp_info,
username=self.properties.username,
)
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError("Callback result must be of type OIDCCallbackResult")
self.refresh_token = resp.refresh_token
self.access_token = resp.access_token
self.token_gen_id += 1
return self.access_token
def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
try:
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as e:
if self._is_auth_error(e):
self._invalidate(conn)
raise
def _is_auth_error(self, err: Exception) -> bool:
if not isinstance(err, OperationFailure):
return False
return err.code == _AUTHENTICATION_FAILURE_CODE
def _invalidate(self, conn: Connection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.
token_gen_id = conn.oidc_token_gen_id or 0
if token_gen_id is not None and token_gen_id < self.token_gen_id:
return
self.access_token = None
def _sasl_continue_jwt(
self, conn: Connection, start_resp: Mapping[str, Any]
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
start_payload: dict = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
return self._run_command(conn, cmd)
def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_start_command({"jwt": access_token})
return self._run_command(conn, cmd)
def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
if payload is None:
principal_name = self.username
if principal_name:
payload = {"n": principal_name}
else:
payload = {}
bin_payload = Binary(bson.encode(payload))
return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}
def _get_continue_command(
self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
) -> MutableMapping[str, Any]:
bin_payload = Binary(bson.encode(payload))
return {
"saslContinue": 1,
"payload": bin_payload,
"conversationId": start_resp["conversationId"],
}
def _authenticate_oidc(
credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, conn.address)
if reauthenticate:
return authenticator.reauthenticate(conn)
else:
return authenticator.authenticate(conn)
__doc__ = original_doc

View File

@ -1,489 +1,21 @@
# Copyright 2017 MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Watch changes on a collection, a database, or the entire cluster."""
"""Re-import of synchronous ChangeStream API for compatibility."""
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from pymongo.synchronous.change_stream import * # noqa: F403
from pymongo.synchronous.change_stream import __doc__ as original_doc
from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot, common
from pymongo.aggregation import (
_AggregationCommand,
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
)
from pymongo.collation import validate_collation_or_none
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (
ConnectionFailure,
CursorNotFound,
InvalidOperation,
OperationFailure,
PyMongoError,
)
from pymongo.operations import _Op
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
# The change streams spec considers the following server errors from the
# getMore command non-resumable. All other getMore errors are resumable.
_RESUMABLE_GETMORE_ERRORS = frozenset(
[
6, # HostUnreachable
7, # HostNotFound
89, # NetworkTimeout
91, # ShutdownInProgress
189, # PrimarySteppedDown
262, # ExceededTimeLimit
9001, # SocketException
10107, # NotWritablePrimary
11600, # InterruptedAtShutdown
11602, # InterruptedDueToReplStateChange
13435, # NotPrimaryNoSecondaryOk
13436, # NotPrimaryOrSecondary
63, # StaleShardVersion
150, # StaleEpoch
13388, # StaleConfig
234, # RetryChangeStream
133, # FailedToSatisfyReadPreference
]
)
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.mongo_client import MongoClient
from pymongo.pool import Connection
def _resumable(exc: PyMongoError) -> bool:
"""Return True if given a resumable change stream error."""
if isinstance(exc, (ConnectionFailure, CursorNotFound)):
return True
if isinstance(exc, OperationFailure):
if exc._max_wire_version is None:
return False
return (
exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
return False
class ChangeStream(Generic[_DocumentType]):
"""The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.collection.Collection.watch`,
:meth:`pymongo.database.Database.watch`, or
:meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded:: 3.6
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
"""
def __init__(
self,
target: Union[
MongoClient[_DocumentType], Database[_DocumentType], Collection[_DocumentType]
],
pipeline: Optional[_Pipeline],
full_document: Optional[str],
resume_after: Optional[Mapping[str, Any]],
max_await_time_ms: Optional[int],
batch_size: Optional[int],
collation: Optional[_CollationIn],
start_at_operation_time: Optional[Timestamp],
session: Optional[ClientSession],
start_after: Optional[Mapping[str, Any]],
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
show_expanded_events: Optional[bool] = None,
) -> None:
if pipeline is None:
pipeline = []
pipeline = common.validate_list("pipeline", pipeline)
common.validate_string_or_none("full_document", full_document)
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._decode_custom = False
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options( # type: ignore
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
)
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document
self._full_document_before_change = full_document_before_change
self._uses_start_after = start_after is not None
self._uses_resume_after = resume_after is not None
self._resume_token = copy.deepcopy(start_after or resume_after)
self._max_await_time_ms = max_await_time_ms
self._batch_size = batch_size
self._collation = collation
self._start_at_operation_time = start_at_operation_time
self._session = session
self._comment = comment
self._closed = False
self._timeout = self._target._timeout
self._show_expanded_events = show_expanded_events
# Initialize cursor.
self._cursor = self._create_cursor()
@property
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
"""The aggregation command class to be used."""
raise NotImplementedError
@property
def _client(self) -> MongoClient:
"""The client against which the aggregation commands for
this ChangeStream will be run.
"""
raise NotImplementedError
def _change_stream_options(self) -> dict[str, Any]:
"""Return the options dict for the $changeStream pipeline stage."""
options: dict[str, Any] = {}
if self._full_document is not None:
options["fullDocument"] = self._full_document
if self._full_document_before_change is not None:
options["fullDocumentBeforeChange"] = self._full_document_before_change
resume_token = self.resume_token
if resume_token is not None:
if self._uses_start_after:
options["startAfter"] = resume_token
else:
options["resumeAfter"] = resume_token
elif self._start_at_operation_time is not None:
options["startAtOperationTime"] = self._start_at_operation_time
if self._show_expanded_events:
options["showExpandedEvents"] = self._show_expanded_events
return options
def _command_options(self) -> dict[str, Any]:
"""Return the options dict for the aggregation command."""
options = {}
if self._max_await_time_ms is not None:
options["maxAwaitTimeMS"] = self._max_await_time_ms
if self._batch_size is not None:
options["batchSize"] = self._batch_size
return options
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
"""Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents.
This is implemented as a callback because we need access to the wire
version in order to determine whether to cache this value.
"""
if not result["cursor"]["firstBatch"]:
if "postBatchResumeToken" in result["cursor"]:
self._resume_token = result["cursor"]["postBatchResumeToken"]
elif (
self._start_at_operation_time is None
and self._uses_resume_after is False
and self._uses_start_after is False
and conn.max_wire_version >= 7
):
self._start_at_operation_time = result.get("operationTime")
# PYTHON-2181: informative error on missing operationTime.
if self._start_at_operation_time is None:
raise OperationFailure(
"Expected field 'operationTime' missing from command "
f"response : {result!r}"
)
def _run_aggregation_cmd(
self, session: Optional[ClientSession], explicit_session: bool
) -> CommandCursor:
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding CommandCursor.
"""
cmd = self._aggregation_command_class(
self._target,
CommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
return self._client._retryable_read(
cmd.get_cursor,
self._target._read_preference_for(session),
session,
operation=_Op.AGGREGATE,
)
def _create_cursor(self) -> CommandCursor:
with self._client._tmp_session(self._session, close=False) as s:
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""
try:
self._cursor.close()
except PyMongoError:
pass
self._cursor = self._create_cursor()
def close(self) -> None:
"""Close this ChangeStream."""
self._closed = True
self._cursor.close()
def __iter__(self) -> ChangeStream[_DocumentType]:
return self
@property
def resume_token(self) -> Optional[Mapping[str, Any]]:
"""The cached resume token that will be used to resume after the most
recently returned change.
.. versionadded:: 3.9
"""
return copy.deepcopy(self._resume_token)
@_csot.apply
def next(self) -> _DocumentType:
"""Advance the cursor.
This method blocks until the next change document is returned or an
unrecoverable error is raised. This method is used when iterating over
all changes in the cursor. For example::
try:
resume_token = None
pipeline = [{'$match': {'operationType': 'insert'}}]
with db.collection.watch(pipeline) as stream:
for insert_change in stream:
print(insert_change)
resume_token = stream.resume_token
except pymongo.errors.PyMongoError:
# The ChangeStream encountered an unrecoverable error or the
# resume attempt failed to recreate the cursor.
if resume_token is None:
# There is no usable resume token because there was a
# failure during ChangeStream initialization.
logging.error('...')
else:
# Use the interrupted ChangeStream's resume token to create
# a new ChangeStream. The new stream will continue from the
# last seen insert change without missing any events.
with db.collection.watch(
pipeline, resume_after=resume_token) as stream:
for insert_change in stream:
print(insert_change)
Raises :exc:`StopIteration` if this ChangeStream is closed.
"""
while self.alive:
doc = self.try_next()
if doc is not None:
return doc
raise StopIteration
__next__ = next
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
.. versionadded:: 3.8
"""
return not self._closed
@_csot.apply
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
indefinitely for the next change. For example::
with db.collection.watch() as stream:
while stream.alive:
change = stream.try_next()
# Note that the ChangeStream's resume token may be updated
# even when no changes are returned.
print("Current resume token: %r" % (stream.resume_token,))
if change is not None:
print("Change document: %r" % (change,))
continue
# We end up here when there are no recent changes.
# Sleep for a while before trying again to avoid flooding
# the server with getMore requests when no changes are
# available.
time.sleep(10)
If no change document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there have been no changes) then ``None`` is returned.
:return: The next change document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 3.8
"""
if not self._closed and not self._cursor.alive:
self._resume()
# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
try:
change = self._cursor._try_next(True)
except PyMongoError as exc:
if not _resumable(exc):
raise
self._resume()
change = self._cursor._try_next(False)
except PyMongoError as exc:
# Close the stream after a fatal error.
if not _resumable(exc) and not exc.timeout:
self.close()
raise
except Exception:
self.close()
raise
# Check if the cursor was invalidated.
if not self._cursor.alive:
self._closed = True
# If no changes are available.
if change is None:
# We have either iterated over all documents in the cursor,
# OR the most-recently returned batch is empty. In either case,
# update the cached resume token with the postBatchResumeToken if
# one was returned. We also clear the startAtOperationTime.
if self._cursor._post_batch_resume_token is not None:
self._resume_token = self._cursor._post_batch_resume_token
self._start_at_operation_time = None
return change
# Else, changes are available.
try:
resume_token = change["_id"]
except KeyError:
self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume token is missing."
) from None
# If this is the last change document from the current batch, cache the
# postBatchResumeToken.
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
resume_token = self._cursor._post_batch_resume_token
# Hereafter, don't use startAfter; instead use resumeAfter.
self._uses_start_after = False
self._uses_resume_after = True
# Cache the resume token and clear startAtOperationTime.
self._resume_token = resume_token
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
def __enter__(self) -> ChangeStream[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
class CollectionChangeStream(ChangeStream[_DocumentType]):
"""A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.collection.Collection.watch` instead.
.. versionadded:: 3.7
"""
_target: Collection[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
return _CollectionAggregationCommand
@property
def _client(self) -> MongoClient[_DocumentType]:
return self._target.database.client
class DatabaseChangeStream(ChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.database.Database.watch` instead.
.. versionadded:: 3.7
"""
_target: Database[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
return _DatabaseAggregationCommand
@property
def _client(self) -> MongoClient[_DocumentType]:
return self._target.client
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded:: 3.7
"""
def _change_stream_options(self) -> dict[str, Any]:
options = super()._change_stream_options()
options["allChangesForCluster"] = True
return options
__doc__ = original_doc

View File

@ -1,332 +1,21 @@
# Copyright 2014-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to parse mongo client options."""
"""Re-import of synchronous ClientOptions API for compatibility."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
from pymongo.synchronous.client_options import * # noqa: F403
from pymongo.synchronous.client_options import __doc__ as original_doc
from bson.codec_options import _parse_codec_options
from pymongo import common
from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListener, _EventListeners
from pymongo.pool import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import (
_ServerMode,
make_read_preference,
read_pref_mode_from_name,
)
from pymongo.server_selectors import any_server_selector
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern, validate_boolean
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.auth import MongoCredential
from pymongo.encryption_options import AutoEncryptionOpts
from pymongo.pyopenssl_context import SSLContext
from pymongo.topology_description import _ServerSelector
def _parse_credentials(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
"""Parse authentication credentials."""
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
source = options.get("authsource")
if username or mechanism:
from pymongo.auth import _build_credentials_tuple
return _build_credentials_tuple(mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
"""Parse read preference options."""
if "read_preference" in options:
return options["read_preference"]
name = options.get("readpreference", "primary")
mode = read_pref_mode_from_name(name)
tags = options.get("readpreferencetags")
max_staleness = options.get("maxstalenessseconds", -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
"""Parse write concern options."""
concern = options.get("w")
wtimeout = options.get("wtimeoutms")
j = options.get("journal")
fsync = options.get("fsync")
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
"""Parse read concern options."""
concern = options.get("readconcernlevel")
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
validate_boolean("tls", use_tls)
certfile = options.get("tlscertificatekeyfile")
passphrase = options.get("tlscertificatekeyfilepassword")
ca_certs = options.get("tlscafile")
crlfile = options.get("tlscrlfile")
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
enabled_tls_opts = []
for opt in (
"tlscertificatekeyfile",
"tlscertificatekeyfilepassword",
"tlscafile",
"tlscrlfile",
):
# Any non-null value of these options implies tls=True.
if opt in options and options[opt]:
enabled_tls_opts.append(opt)
for opt in (
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
):
# A value of False for these options implies tls=True.
if opt in options and not options[opt]:
enabled_tls_opts.append(opt)
if enabled_tls_opts:
if use_tls is None:
# Implicitly enable TLS when one of the tls* options is set.
use_tls = True
elif not use_tls:
# Error since tls is explicitly disabled but a tls option is set.
raise ConfigurationError(
"TLS has not been enabled but the "
"following tls parameters have been set: "
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
)
if use_tls:
ctx = get_ssl_context(
certfile,
passphrase,
ca_certs,
crlfile,
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
def _parse_pool_options(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(
max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout,
socket_timeout,
wait_queue_timeout,
ssl_context,
tls_allow_invalid_hostnames,
_EventListeners(event_listeners),
appname,
driver,
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced,
credentials=credentials,
)
class ClientOptions:
"""Read only configuration options for a MongoClient.
Should not be instantiated directly by application developers. Access
a client's options via :attr:`pymongo.mongo_client.MongoClient.options`
instead.
"""
def __init__(
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__direct_connection = options.get("directconnection")
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
)
self.__pool_options = _parse_pool_options(username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get("replicaset")
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get("connect")
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
self.__server_selector = options.get("server_selector", any_server_selector)
self.__auto_encryption_opts = options.get("auto_encryption_opts")
self.__load_balanced = options.get("loadbalanced")
self.__timeout = options.get("timeoutms")
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
@property
def _options(self) -> Mapping[str, Any]:
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self) -> Optional[bool]:
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self) -> CodecOptions:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def direct_connection(self) -> Optional[bool]:
"""Whether to connect to the deployment in 'Single' topology."""
return self.__direct_connection
@property
def local_threshold_ms(self) -> int:
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self) -> int:
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self) -> _ServerSelector:
return self.__server_selector
@property
def heartbeat_frequency(self) -> int:
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self) -> PoolOptions:
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self) -> _ServerMode:
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self) -> WriteConcern:
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self) -> ReadConcern:
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def timeout(self) -> Optional[float]:
"""The configured timeoutMS converted to seconds, or None.
.. versionadded:: 4.2
"""
return self.__timeout
@property
def retry_writes(self) -> bool:
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self) -> bool:
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self.__load_balanced
@property
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()
@property
def server_monitoring_mode(self) -> str:
"""The configured serverMonitoringMode option.
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode
__doc__ = original_doc

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# Copyright 2016 MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,213 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with `collations`_.
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
"""Re-import of synchronous Collation API for compatibility."""
from __future__ import annotations
from typing import Any, Mapping, Optional, Union
from pymongo.synchronous.collation import * # noqa: F403
from pymongo.synchronous.collation import __doc__ as original_doc
from pymongo import common
from pymongo.write_concern import validate_boolean
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = "non-ignorable"
"""Spaces and punctuation are treated as base characters."""
SHIFTED = "shifted"
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished regardless when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = "punct"
"""Both punctuation and spaces are ignored."""
SPACE = "space"
"""Spaces alone are ignored."""
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = "upper"
"""Sort uppercase characters first."""
LOWER = "lower"
"""Sort lowercase characters first."""
OFF = "off"
"""Default for locale or collation strength."""
class Collation:
"""Collation
:param locale: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
:param caseFirst: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
:param strength: Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
:param numericOrdering: If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
:param alternate: Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
:param maxVariable: When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
:param normalization: If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
:param backwards: If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
:param kwargs: Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded: 3.4
"""
__slots__ = ("__document",)
def __init__(
self,
locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
if strength is not None:
self.__document["strength"] = common.validate_integer("strength", strength)
if numericOrdering is not None:
self.__document["numericOrdering"] = validate_boolean(
"numericOrdering", numericOrdering
)
if alternate is not None:
self.__document["alternate"] = common.validate_string("alternate", alternate)
if maxVariable is not None:
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
if normalization is not None:
self.__document["normalization"] = validate_boolean("normalization", normalization)
if backwards is not None:
self.__document["backwards"] = validate_boolean("backwards", backwards)
self.__document.update(kwargs)
@property
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self) -> str:
document = self.document
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
__doc__ = original_doc

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# Copyright 2014-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,390 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""CommandCursor class to iterate over command results."""
"""Re-import of synchronous CommandCursor API for compatibility."""
from __future__ import annotations
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
from pymongo.synchronous.command_cursor import * # noqa: F403
from pymongo.synchronous.command_cursor import __doc__ as original_doc
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.pool import Connection
class CommandCursor(Generic[_DocumentType]):
"""A cursor / iterator over command cursors."""
_getmore_class = _GetMore
def __init__(
self,
collection: Collection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
self.__sock_mgr: Any = None
self.__collection: Collection[_DocumentType] = collection
self.__id = cursor_info["id"]
self.__data = deque(cursor_info["firstBatch"])
self.__postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
"postBatchResumeToken"
)
self.__address = address
self.__batch_size = batch_size
self.__max_await_time_ms = max_await_time_ms
self.__session = session
self.__explicit_session = explicit_session
self.__killed = self.__id == 0
self.__comment = comment
if self.__killed:
self.__end_session(True)
if "ns" in cursor_info: # noqa: SIM401
self.__ns = cursor_info["ns"]
else:
self.__ns = collection.full_name
self.batch_size(batch_size)
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self) -> None:
self.__die()
def __die(self, synchronous: bool = False) -> None:
"""Closes this cursor."""
already_killed = self.__killed
self.__killed = True
if self.__id and not already_killed:
cursor_id = self.__id
assert self.__address is not None
address = _CursorAddress(self.__address, self.__ns)
else:
# Skip killCursors.
cursor_id = 0
address = None
self.__collection.database.client._cleanup_cursor(
synchronous,
cursor_id,
address,
self.__sock_mgr,
self.__session,
self.__explicit_session,
)
if not self.__explicit_session:
self.__session = None
self.__sock_mgr = None
def __end_session(self, synchronous: bool) -> None:
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
self.__session = None
def close(self) -> None:
"""Explicitly close / kill this cursor."""
self.__die(True)
def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
.. note:: batch_size can not override MongoDB's internal limits on the
amount of data it will return to the client in a single batch (i.e
if you set batch size to 1,000,000,000, MongoDB will currently only
return 4-16MB of results per batch).
Raises :exc:`TypeError` if `batch_size` is not an integer.
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self.__batch_size = batch_size == 1 and 2 or batch_size
return self
def _has_next(self) -> bool:
"""Returns `True` if the cursor has documents remaining from the
previous batch.
"""
return len(self.__data) > 0
@property
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore.
"""
return self.__postbatchresumetoken
def _maybe_pin_connection(self, conn: Connection) -> None:
client = self.__collection.database.client
if not client._should_pin_cursor(self.__session):
return
if not self.__sock_mgr:
conn.pin_cursor()
conn_mgr = _ConnectionManager(conn, False)
# Ensure the connection gets returned when the entire result is
# returned in the first batch.
if self.__id == 0:
conn_mgr.close()
else:
self.__sock_mgr = conn_mgr
def __send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
client = self.__collection.database.client
try:
response = client._run_operation(
operation, self._unpack_response, address=self.__address
)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS:
# Don't send killCursors because the cursor is already closed.
self.__killed = True
if exc.timeout:
self.__die(False)
else:
# Return the session and pinned connection, if necessary.
self.close()
raise
except ConnectionFailure:
# Don't send killCursors because the cursor is already closed.
self.__killed = True
# Return the session and pinned connection, if necessary.
self.close()
raise
except Exception:
self.close()
raise
if isinstance(response, PinnedResponse):
if not self.__sock_mgr:
self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
self.__postbatchresumetoken = cursor.get("postBatchResumeToken")
self.__id = cursor["id"]
else:
documents = response.docs
assert isinstance(response.data, _OpReply)
self.__id = response.data.cursor_id
if self.__id == 0:
self.close()
self.__data = deque(documents)
def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
def _refresh(self) -> int:
"""Refreshes the cursor with more data from the server.
Returns the length of self.__data after refresh. Will exit early if
self.__data is already non-empty. Raises OperationFailure when the
cursor cannot be refreshed due to an error on the query.
"""
if len(self.__data) or self.__killed:
return len(self.__data)
if self.__id: # Get More
dbname, collname = self.__ns.split(".", 1)
read_pref = self.__collection._read_preference_for(self.session)
self.__send_message(
self._getmore_class(
dbname,
collname,
self.__batch_size,
self.__id,
self.__collection.codec_options,
read_pref,
self.__session,
self.__collection.database.client,
self.__max_await_time_ms,
self.__sock_mgr,
False,
self.__comment,
)
)
else: # Cursor id is zero nothing else to return
self.__die(True)
return len(self.__data)
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration`. Best to use a for loop::
for doc in collection.aggregate(pipeline):
print(doc)
.. note:: :attr:`alive` can be True while iterating a cursor from
a failed server. In this case :attr:`alive` will return False after
:meth:`next` fails to retrieve the next batch of results from the
server.
"""
return bool(len(self.__data) or (not self.__killed))
@property
def cursor_id(self) -> int:
"""Returns the id of the cursor."""
return self.__id
@property
def address(self) -> Optional[_Address]:
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
"""
return self.__address
@property
def session(self) -> Optional[ClientSession]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
.. versionadded:: 3.6
"""
if self.__explicit_session:
return self.__session
return None
def __iter__(self) -> Iterator[_DocumentType]:
return self
def next(self) -> _DocumentType:
"""Advance the cursor."""
# Block until a document is returnable.
while self.alive:
doc = self._try_next(True)
if doc is not None:
return doc
raise StopIteration
__next__ = next
def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
"""Advance the cursor blocking for at most one getMore command."""
if not len(self.__data) and not self.__killed and get_more_allowed:
self._refresh()
if len(self.__data):
return self.__data.popleft()
else:
return None
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next document without waiting
indefinitely for data.
If no document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there is no additional data) then ``None`` is returned.
:return: The next document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 4.5
"""
return self._try_next(get_more_allowed=True)
def __enter__(self) -> CommandCursor[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
_getmore_class = _RawBatchGetMore
def __init__(
self,
collection: Collection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
see :meth:`~pymongo.collection.Collection.aggregate_raw_batches`
instead.
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
assert not cursor_info.get("firstBatch")
super().__init__(
collection,
cursor_info,
address,
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)
def _unpack_response( # type: ignore[override]
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[Mapping[str, Any]]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
# Re-assemble the array of documents into a document stream
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response # type: ignore[return-value]
def __getitem__(self, index: int) -> NoReturn:
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
__doc__ = original_doc

File diff suppressed because it is too large Load Diff

94
pymongo/cursor_shared.py Normal file
View File

@ -0,0 +1,94 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Constants and types shared across all cursor classes."""
from __future__ import annotations
from typing import Any, Mapping, Sequence, Tuple, Union
# These errors mean that the server has already killed the cursor so there is
# no need to send killCursors.
_CURSOR_CLOSED_ERRORS = frozenset(
[
43, # CursorNotFound
175, # QueryPlanKilled
237, # CursorKilled
# On a tailable cursor, the following errors mean the capped collection
# rolled over.
# MongoDB 2.6:
# {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0}
28617,
# MongoDB 3.0:
# {'$err': 'getMore executor error: UnknownError no details available',
# 'code': 17406, 'ok': 0}
17406,
# MongoDB 3.2 + 3.4:
# {'ok': 0.0, 'errmsg': 'GetMore command executor error:
# CappedPositionLost: CollectionScan died due to failure to restore
# tailable cursor position. Last seen record id: RecordId(3)',
# 'code': 96}
96,
# MongoDB 3.6+:
# {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to
# restore tailable cursor position. Last seen record id: RecordId(3)"',
# 'code': 136, 'codeName': 'CappedPositionLost'}
136,
]
)
_QUERY_OPTIONS = {
"tailable_cursor": 2,
"secondary_okay": 4,
"oplog_replay": 8,
"no_timeout": 16,
"await_data": 32,
"exhaust": 64,
"partial": 128,
}
class CursorType:
NON_TAILABLE = 0
"""The standard cursor type."""
TAILABLE = _QUERY_OPTIONS["tailable_cursor"]
"""The tailable cursor type.
Tailable cursors are only for use with capped collections. They are not
closed when the last data is retrieved but are kept open and the cursor
location marks the final document position. If more data is received
iteration of the cursor will continue from the last document received.
"""
TAILABLE_AWAIT = TAILABLE | _QUERY_OPTIONS["await_data"]
"""A tailable cursor with the await option set.
Creates a tailable cursor that will wait for a few seconds after returning
the full result set so that it can capture and return additional data added
during the query.
"""
EXHAUST = _QUERY_OPTIONS["exhaust"]
"""An exhaust cursor.
MongoDB will stream batched results to the client without waiting for the
client to request each batch, reducing latency.
"""
_Sort = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_Hint = Union[str, _Sort]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,34 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Constants, helpers, and types shared across all database classes."""
from __future__ import annotations
from typing import Any, Mapping, TypeVar
from pymongo.errors import InvalidName
def _check_name(name: str) -> None:
"""Check if a database name is valid."""
if not name:
raise InvalidName("database name cannot be the empty string")
for invalid_char in [" ", ".", "$", "/", "\\", "\x00", '"']:
if invalid_char in name:
raise InvalidName("database names cannot contain the character %r" % invalid_char)
_CodecDocumentType = TypeVar("_CodecDocumentType", bound=Mapping[str, Any])

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# Copyright 2019-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,257 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for automatic client-side field level encryption."""
"""Re-import of synchronous EncryptionOptions API for compatibility."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
from pymongo.synchronous.encryption_options import * # noqa: F403
from pymongo.synchronous.encryption_options import __doc__ as original_doc
try:
import pymongocrypt # type:ignore[import] # noqa: F401
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.common import validate_is_mapping
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import _parse_kms_tls_options
if TYPE_CHECKING:
from pymongo.mongo_client import MongoClient
from pymongo.typings import _DocumentTypeArg
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
def __init__(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: bool = False,
mongocryptd_uri: str = "mongodb://localhost:27020",
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = "mongocryptd",
mongocryptd_spawn_args: Optional[list[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None,
crypt_shared_lib_path: Optional[str] = None,
crypt_shared_lib_required: bool = False,
bypass_query_analysis: bool = False,
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
) -> None:
"""Options to configure automatic client-side field level encryption.
Automatic client-side field level encryption requires MongoDB >=4.2
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
supported for operations on a database or view and will result in
error.
Although automatic encryption requires MongoDB >=4.2 enterprise or a
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
users. To configure automatic *decryption* without automatic
*encryption* set ``bypass_auto_encryption=True``. Explicit
encryption and explicit decryption is also supported for all users
with the :class:`~pymongo.encryption.ClientEncryption` class.
See :ref:`automatic-client-side-encryption` for an example.
:param kms_providers: Map of KMS provider options. The `kms_providers`
map values differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages. An optional "sessionToken" may be
included to support temporary AWS credentials.
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
strings. Additionally, "identityPlatformEndpoint" may also be
specified as a string (defaults to 'login.microsoftonline.com').
These are the Azure Active Directory credentials used to
generate Azure Key Vault messages.
- `gcp`: Map with "email" as a string and "privateKey"
as `bytes` or a base64 encoded string.
Additionally, "endpoint" may also be specified as a string
(defaults to 'oauth2.googleapis.com'). These are the
credentials used to generate Google Cloud KMS messages.
- `kmip`: Map with "endpoint" as a host with required port.
For example: ``{"endpoint": "example.com:443"}``.
- `local`: Map with "key" as `bytes` (96 bytes in length) or
a base64 encoded string which decodes
to 96 bytes. "key" is the master key used to encrypt/decrypt
data keys. This key should be generated and stored as securely
as possible.
KMS providers may be specified with an optional name suffix
separated by a colon, for example "kmip:name" or "aws:name".
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
Named KMS providers enables more than one of each KMS provider type to be configured.
For example, to configure multiple local KMS providers::
kms_providers = {
"local": {"key": local_kek1}, # Unnamed KMS provider.
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
}
:param key_vault_namespace: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
:param key_vault_client: By default, the key vault collection
is assumed to reside in the same MongoDB cluster as the encrypted
MongoClient. Use this option to route data key queries to a
separate MongoDB cluster.
:param schema_map: Map of collection namespace ("db.coll") to
JSON Schema. By default, a collection's JSONSchema is periodically
polled with the listCollections command. But a JSONSchema may be
specified locally with the schemaMap option.
**Supplying a `schema_map` provides more security than relying on
JSON Schemas obtained from the server. It protects against a
malicious server advertising a false JSON Schema, which could trick
the client into sending unencrypted data that should be
encrypted.**
Schemas supplied in the schemaMap only apply to configuring
automatic encryption for client side encryption. Other validation
rules in the JSON schema will not be enforced by the driver and
will result in an error.
:param bypass_auto_encryption: If ``True``, automatic
encryption will be disabled but automatic decryption will still be
enabled. Defaults to ``False``.
:param mongocryptd_uri: The MongoDB URI used to connect
to the *local* mongocryptd process. Defaults to
``'mongodb://localhost:27020'``.
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
MongoClient will not attempt to spawn the mongocryptd process.
Defaults to ``False``.
:param mongocryptd_spawn_path: Used for spawning the
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
mongocryptd from the system path.
:param mongocryptd_spawn_args: A list of string arguments to
use when spawning the mongocryptd process. Defaults to
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
the ``idleShutdownTimeoutSecs`` option then
``'--idleShutdownTimeoutSecs=60'`` will be added.
:param kms_tls_options: A map of KMS provider names to TLS
options to use when creating secure connections to KMS providers.
Accepts the same TLS options as
:class:`pymongo.mongo_client.MongoClient`. For example, to
override the system default CA file::
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
Or to supply a client certificate::
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
unable to load the crypt_shared library.
:param bypass_query_analysis: If ``True``, disable automatic analysis
of outgoing commands. Set `bypass_query_analysis` to use explicit
encryption on indexed fields without the MongoDB Enterprise Advanced
licensed crypt_shared library.
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
that described the encrypted fields for Queryable Encryption. For example::
{
"db.encryptedCollection": {
"escCollection": "enxcol_.encryptedCollection.esc",
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
"fields": [
{
"path": "firstName",
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
"bsonType": "string",
"queries": {"queryType": "equality"}
},
{
"path": "ssn",
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
"bsonType": "string"
}
]
}
}
.. versionchanged:: 4.2
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
and `bypass_query_analysis` parameters.
.. versionchanged:: 4.0
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
self._bypass_query_analysis = bypass_query_analysis
self._crypt_shared_lib_path = crypt_shared_lib_path
self._crypt_shared_lib_required = crypt_shared_lib_required
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._schema_map = schema_map
self._bypass_auto_encryption = bypass_auto_encryption
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
if mongocryptd_spawn_args is None:
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError("mongocryptd_spawn_args must be a list")
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
self._bypass_query_analysis = bypass_query_analysis
class RangeOpts:
"""Options to configure encrypted queries using the rangePreview algorithm."""
def __init__(
self,
sparsity: int,
min: Optional[Any] = None,
max: Optional[Any] = None,
precision: Optional[int] = None,
) -> None:
"""Options to configure encrypted queries using the rangePreview algorithm.
.. note:: This feature is experimental only, and not intended for public use.
:param sparsity: An integer.
:param min: A BSON scalar value corresponding to the type being queried.
:param max: A BSON scalar value corresponding to the type being queried.
:param precision: An integer, may only be set for double or decimal128 types.
.. versionadded:: 4.4
"""
self.min = min
self.max = max
self.sparsity = sparsity
self.precision = precision
@property
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),
("precision", self.precision),
("min", self.min),
("max", self.max),
]:
if v is not None:
doc[k] = v
return doc
__doc__ = original_doc

View File

@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Un
from bson.errors import InvalidDocument
if TYPE_CHECKING:
from pymongo.typings import _DocumentOut
from pymongo.asynchronous.typings import _DocumentOut
class PyMongoError(Exception):

View File

@ -1,4 +1,4 @@
# Copyright 2020-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,212 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example event logger classes.
.. versionadded:: 3.11
These loggers can be registered using :func:`register` or
:class:`~pymongo.mongo_client.MongoClient`.
``monitoring.register(CommandLogger())``
or
``MongoClient(event_listeners=[CommandLogger()])``
"""
"""Re-import of synchronous EventLoggers API for compatibility."""
from __future__ import annotations
import logging
from pymongo.synchronous.event_loggers import * # noqa: F403
from pymongo.synchronous.event_loggers import __doc__ as original_doc
from pymongo import monitoring
class CommandLogger(monitoring.CommandListener):
"""A simple listener that logs command events.
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
:class:`~pymongo.monitoring.CommandSucceededEvent` and
:class:`~pymongo.monitoring.CommandFailedEvent` events and
logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.CommandStartedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} started on server "
f"{event.connection_id}"
)
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"succeeded in {event.duration_micros} "
"microseconds"
)
def failed(self, event: monitoring.CommandFailedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"failed in {event.duration_micros} "
"microseconds"
)
class ServerLogger(monitoring.ServerListener):
"""A simple listener that logs server discovery events.
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.ServerClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
previous_server_type = event.previous_description.server_type
new_server_type = event.new_description.server_type
if new_server_type != previous_server_type:
# server_type_name was added in PyMongo 3.4
logging.info(
f"Server {event.server_address} changed type from "
f"{event.previous_description.server_type_name} to "
f"{event.new_description.server_type_name}"
)
def closed(self, event: monitoring.ServerClosedEvent) -> None:
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
"""A simple listener that logs server heartbeat events.
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
logging.info(f"Heartbeat sent to server {event.connection_id}")
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
# The reply.document attribute was added in PyMongo 3.4.
logging.info(
f"Heartbeat to server {event.connection_id} "
"succeeded with reply "
f"{event.reply.document}"
)
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
logging.warning(
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
)
class TopologyLogger(monitoring.TopologyListener):
"""A simple listener that logs server topology events.
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.TopologyClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} opened")
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
logging.info(f"Topology description updated for topology id {event.topology_id}")
previous_topology_type = event.previous_description.topology_type
new_topology_type = event.new_description.topology_type
if new_topology_type != previous_topology_type:
# topology_type_name was added in PyMongo 3.4
logging.info(
f"Topology {event.topology_id} changed type from "
f"{event.previous_description.topology_type_name} to "
f"{event.new_description.topology_type_name}"
)
# The has_writable_server and has_readable_server methods
# were added in PyMongo 3.4.
if not event.new_description.has_writable_server():
logging.warning("No writable servers available.")
if not event.new_description.has_readable_server():
logging.warning("No readable servers available.")
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} closed")
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
"""A simple listener that logs server connection pool events.
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
:class:`~pymongo.monitoring.PoolClearedEvent`,
:class:`~pymongo.monitoring.PoolClosedEvent`,
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
logging.info(f"[pool {event.address}] pool created")
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
logging.info(f"[pool {event.address}] pool ready")
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
logging.info(f"[pool {event.address}] pool cleared")
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
logging.info(f"[pool {event.address}] pool closed")
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
)
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] "
f'connection closed, reason: "{event.reason}"'
)
def connection_check_out_started(
self, event: monitoring.ConnectionCheckOutStartedEvent
) -> None:
logging.info(f"[pool {event.address}] connection check out started")
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
)
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
)
__doc__ = original_doc

View File

@ -0,0 +1,72 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used by the driver that don't really fit elsewhere."""
# From the SDAM spec, the "node is shutting down" codes.
from __future__ import annotations
_SHUTDOWN_CODES: frozenset = frozenset(
[
11600, # InterruptedAtShutdown
91, # ShutdownInProgress
]
)
# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = (
frozenset(
[
10058, # LegacyNotPrimary <=3.2 "not primary" error code
10107, # NotWritablePrimary
13435, # NotPrimaryNoSecondaryOk
11602, # InterruptedDueToReplStateChange
13436, # NotPrimaryOrSecondary
189, # PrimarySteppedDown
]
)
| _SHUTDOWN_CODES
)
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
[
7, # HostNotFound
6, # HostUnreachable
89, # NetworkTimeout
9001, # SocketException
262, # ExceededTimeLimit
134, # ReadConcernMajorityNotAvailableYet
]
)
# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391
# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18
# Note - to avoid bugs from forgetting which if these is all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = {
"authenticate",
"saslstart",
"saslcontinue",
"getnonce",
"createuser",
"updateuser",
"copydbgetnonce",
"copydbsaslstart",
"copydb",
}

View File

@ -13,9 +13,12 @@
# limitations under the License.
from __future__ import annotations
import asyncio
import os
import threading
import time
import weakref
from typing import Any, Callable, Optional
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
@ -38,3 +41,102 @@ def _release_locks() -> None:
for lock in _forkable_locks:
if lock.locked():
lock.release()
class _ALock:
def __init__(self, lock: threading.Lock) -> None:
self._lock = lock
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
return self._lock.acquire(blocking=blocking, timeout=timeout)
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._lock.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
def release(self) -> None:
self._lock.release()
async def __aenter__(self) -> _ALock:
await self.a_acquire()
return self
def __enter__(self) -> _ALock:
self._lock.acquire()
return self
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
class _ACondition:
def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._condition.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
async def wait(self, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait(0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait_for(predicate, 0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
def notify(self, n: int = 1) -> None:
self._condition.notify(n)
def notify_all(self) -> None:
self._condition.notify_all()
def release(self) -> None:
self._condition.release()
async def __aenter__(self) -> _ACondition:
await self.acquire()
return self
def __enter__(self) -> _ACondition:
self._condition.acquire()
return self
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

49
pymongo/network_layer.py Normal file
View File

@ -0,0 +1,49 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal network layer helper methods."""
from __future__ import annotations
import asyncio
import socket
import struct
from typing import (
TYPE_CHECKING,
Union,
)
from pymongo import ssl_support
if TYPE_CHECKING:
from pymongo.pyopenssl_context import _sslConn
_UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
async def async_sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
timeout = socket.gettimeout()
socket.settimeout(0.0)
loop = asyncio.get_event_loop()
try:
await asyncio.wait_for(loop.sock_sendall(socket, buf), timeout=timeout) # type: ignore[arg-type]
finally:
socket.settimeout(timeout)
def sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
socket.sendall(buf)

View File

@ -1,4 +1,4 @@
# Copyright 2015-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,612 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operation class definitions."""
"""Re-import of synchronous Operations API for compatibility."""
from __future__ import annotations
import enum
from typing import (
TYPE_CHECKING,
Any,
Generic,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from pymongo.synchronous.operations import * # noqa: F403
from pymongo.synchronous.operations import __doc__ as original_doc
from bson.raw_bson import RawBSONDocument
from pymongo import helpers
from pymongo.collation import validate_collation_or_none
from pymongo.common import validate_is_mapping, validate_list
from pymongo.helpers import _gen_index_name, _index_document, _index_list
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.write_concern import validate_boolean
if TYPE_CHECKING:
from pymongo.bulk import _Bulk
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary
_IndexList = Union[
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
_IndexKeyHint = Union[str, _IndexList]
class _Op(str, enum.Enum):
ABORT = "abortTransaction"
AGGREGATE = "aggregate"
COMMIT = "commitTransaction"
COUNT = "count"
CREATE = "create"
CREATE_INDEXES = "createIndexes"
CREATE_SEARCH_INDEXES = "createSearchIndexes"
DELETE = "delete"
DISTINCT = "distinct"
DROP = "drop"
DROP_DATABASE = "dropDatabase"
DROP_INDEXES = "dropIndexes"
DROP_SEARCH_INDEXES = "dropSearchIndexes"
END_SESSIONS = "endSessions"
FIND_AND_MODIFY = "findAndModify"
FIND = "find"
INSERT = "insert"
LIST_COLLECTIONS = "listCollections"
LIST_INDEXES = "listIndexes"
LIST_SEARCH_INDEX = "listSearchIndexes"
LIST_DATABASES = "listDatabases"
UPDATE = "update"
UPDATE_INDEX = "updateIndex"
UPDATE_SEARCH_INDEX = "updateSearchIndex"
RENAME = "rename"
GETMORE = "getMore"
KILL_CURSORS = "killCursors"
TEST = "testOperation"
class InsertOne(Generic[_DocumentType]):
"""Represents an insert_one operation."""
__slots__ = ("_doc",)
def __init__(self, document: _DocumentType) -> None:
"""Create an InsertOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:param document: The document to insert. If the document is missing an
_id field one will be added.
"""
self._doc = document
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_insert(self._doc) # type: ignore[arg-type]
def __repr__(self) -> str:
return f"InsertOne({self._doc!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return other._doc == self._doc
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteOne:
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the document to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
1,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class DeleteMany:
"""Represents a delete_many operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a DeleteMany instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the documents to delete.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.4 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(
self._filter,
0,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __repr__(self) -> str:
return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (other._filter, other._collation, other._hint) == (
self._filter,
self._collation,
self._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
class ReplaceOne(Generic[_DocumentType]):
"""Represents a replace_one operation."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
replacement: Union[_DocumentType, RawBSONDocument],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create a ReplaceOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the document to replace.
:param replacement: The new document.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the ``hint`` option.
.. versionchanged:: 3.5
Added the ``collation`` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = replacement
self._upsert = upsert
self._collation = collation
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_replace(
self._filter,
self._doc,
self._upsert,
collation=validate_collation_or_none(self._collation),
hint=self._hint,
)
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
other._hint,
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._hint,
)
class _UpdateOp:
"""Private base class for update operations."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
def __init__(
self,
filter: Mapping[str, Any],
doc: Union[Mapping[str, Any], _Pipeline],
upsert: bool,
collation: Optional[_CollationIn],
array_filters: Optional[list[Mapping[str, Any]]],
hint: Optional[_IndexKeyHint],
):
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if array_filters is not None:
validate_list("array_filters", array_filters)
if hint is not None and not isinstance(hint, str):
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
else:
self._hint = hint
self._filter = filter
self._doc = doc
self._upsert = upsert
self._collation = collation
self._array_filters = array_filters
def __eq__(self, other: object) -> bool:
if isinstance(other, type(self)):
return (
other._filter,
other._doc,
other._upsert,
other._collation,
other._array_filters,
other._hint,
) == (
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
return NotImplemented
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
self.__class__.__name__,
self._filter,
self._doc,
self._upsert,
self._collation,
self._array_filters,
self._hint,
)
class UpdateOne(_UpdateOp):
"""Represents an update_one operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Represents an update_one operation.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the document to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
False,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class UpdateMany(_UpdateOp):
"""Represents an update_many operation."""
__slots__ = ()
def __init__(
self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create an UpdateMany instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:param filter: A query that matches the documents to update.
:param update: The modifications to apply.
:param upsert: If ``True``, perform an insert if no documents
match the filter.
:param collation: An instance of
:class:`~pymongo.collation.Collation`.
:param array_filters: A list of filters specifying which
array elements an update should apply.
:param hint: An index to use to support the query
predicate specified either by its string name, or in the same
format as passed to
:meth:`~pymongo.collection.Collection.create_index` (e.g.
``[('field', ASCENDING)]``). This option is only supported on
MongoDB 4.2 and above.
.. versionchanged:: 3.11
Added the `hint` option.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super().__init__(filter, update, upsert, collation, array_filters, hint)
def _add_to_bulk(self, bulkobj: _Bulk) -> None:
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(
self._filter,
self._doc,
True,
self._upsert,
collation=validate_collation_or_none(self._collation),
array_filters=self._array_filters,
hint=self._hint,
)
class IndexModel:
"""Represents an index to create."""
__slots__ = ("__document",)
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
"""Create an Index instance.
For use with :meth:`~pymongo.collection.Collection.create_indexes`.
Takes either a single key or a list containing (key, direction) pairs
or keys. If no direction is given, :data:`~pymongo.ASCENDING` will
be assumed.
The key(s) must be an instance of :class:`str`, and the direction(s) must
be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
:data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
Valid options include, but are not limited to:
- `name`: custom name to use for this index - if none is
given, a name will be generated.
- `unique`: if ``True``, creates a uniqueness constraint on the index.
- `background`: if ``True``, this index should be created in the
background.
- `sparse`: if ``True``, omit from the index any documents that lack
the indexed field.
- `bucketSize`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity
to a given longitude and latitude.
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
index.
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
index.
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
collection. MongoDB will automatically delete documents from
this collection after <int> seconds. The indexed field must
be a UTC datetime or the data will not expire.
- `partialFilterExpression`: A document that specifies a filter for
a partial index.
- `collation`: An instance of :class:`~pymongo.collation.Collation`
that specifies the collation to use.
- `wildcardProjection`: Allows users to include or exclude specific
field paths from a `wildcard index`_ using the { "$**" : 1} key
pattern. Requires MongoDB >= 4.2.
- `hidden`: if ``True``, this index will be hidden from the query
planner and will not be evaluated as part of query plan
selection. Requires MongoDB >= 4.4.
See the MongoDB documentation for a full list of supported options by
server version.
:param keys: a single key or a list containing (key, direction) pairs
or keys specifying the index to create.
:param kwargs: any additional index creation
options (see the above list) should be passed as keyword
arguments.
.. versionchanged:: 3.11
Added the ``hidden`` option.
.. versionchanged:: 3.2
Added the ``partialFilterExpression`` option to support partial
indexes.
.. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
"""
keys = _index_list(keys)
if kwargs.get("name") is None:
kwargs["name"] = _gen_index_name(keys)
kwargs["key"] = _index_document(keys)
collation = validate_collation_or_none(kwargs.pop("collation", None))
self.__document = kwargs
if collation is not None:
self.__document["collation"] = collation
@property
def document(self) -> dict[str, Any]:
"""An index document suitable for passing to the createIndexes
command.
"""
return self.__document
class SearchIndexModel:
"""Represents a search index to create."""
__slots__ = ("__document",)
def __init__(
self,
definition: Mapping[str, Any],
name: Optional[str] = None,
type: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a Search Index instance.
For use with :meth:`~pymongo.collection.Collection.create_search_index` and :meth:`~pymongo.collection.Collection.create_search_indexes`.
:param definition: The definition for this index.
:param name: The name for this index, if present.
:param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
:param kwargs: Keyword arguments supplying any additional options.
.. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
.. versionadded:: 4.5
.. versionchanged:: 4.7
Added the type and kwargs arguments.
"""
self.__document: dict[str, Any] = {}
if name is not None:
self.__document["name"] = name
self.__document["definition"] = definition
if type is not None:
self.__document["type"] = type
self.__document.update(kwargs)
@property
def document(self) -> Mapping[str, Any]:
"""The document for this index."""
return self.__document
__doc__ = original_doc

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ context.
"""
from __future__ import annotations
import asyncio
import socket as _socket
import ssl as _stdlibssl
import sys as _sys
@ -364,6 +365,58 @@ class SSLContext:
# but not that same as CPython's.
self._ctx.set_default_verify_paths()
async def a_wrap_socket(
self,
sock: _socket.socket,
server_side: bool = False,
do_handshake_on_connect: bool = True,
suppress_ragged_eofs: bool = True,
server_hostname: Optional[str] = None,
session: Optional[_SSL.Session] = None,
) -> _sslConn:
"""Wrap an existing Python socket connection and return a TLS socket
object.
"""
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
loop = asyncio.get_running_loop()
if session:
ssl_conn.set_session(session)
if server_side is True:
ssl_conn.set_accept_state()
else:
# SNI
if server_hostname and not _is_ip_address(server_hostname):
# XXX: Do this in a callback registered with
# SSLContext.set_info_callback? See Twisted for an example.
ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
if self.verify_mode != _stdlibssl.CERT_NONE:
# Request a stapled OCSP response.
await loop.run_in_executor(None, ssl_conn.request_ocsp)
ssl_conn.set_connect_state()
# If this wasn't true the caller of wrap_socket would call
# do_handshake()
if do_handshake_on_connect:
# XXX: If we do hostname checking in a callback we can get rid
# of this call to do_handshake() since the handshake
# will happen automatically later.
await loop.run_in_executor(None, ssl_conn.do_handshake)
# XXX: Do this in a callback registered with
# SSLContext.set_info_callback? See Twisted for an example.
if self.check_hostname and server_hostname is not None:
from service_identity import pyopenssl
try:
if _is_ip_address(server_hostname):
pyopenssl.verify_ip_address(ssl_conn, server_hostname)
else:
pyopenssl.verify_hostname(ssl_conn, server_hostname)
except ( # type:ignore[misc]
service_identity.SICertificateError,
service_identity.SIVerificationError,
) as exc:
raise _CertificateError(str(exc)) from None
return ssl_conn
def wrap_socket(
self,
sock: _socket.socket,

View File

@ -1,6 +1,6 @@
# Copyright 2012-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License",
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@ -12,611 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for choosing which member of a replica set to read from."""
"""Re-import of synchronous ReadPreferences API for compatibility."""
from __future__ import annotations
from collections import abc
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
from pymongo.synchronous.read_preferences import * # noqa: F403
from pymongo.synchronous.read_preferences import __doc__ as original_doc
from pymongo import max_staleness_selectors
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import (
member_with_tags_server_selector,
secondary_with_tags_server_selector,
)
if TYPE_CHECKING:
from pymongo.server_selectors import Selection
from pymongo.topology_description import TopologyDescription
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
_SECONDARY_PREFERRED = 3
_NEAREST = 4
_MONGOS_MODES = (
"primary",
"primaryPreferred",
"secondary",
"secondaryPreferred",
"nearest",
)
_Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
"""Validate tag sets for a MongoClient."""
if tag_sets is None:
return tag_sets
if not isinstance(tag_sets, (list, tuple)):
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
if len(tag_sets) == 0:
raise ValueError(
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
)
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
f"Tag set {tags!r} invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping"
)
return list(tag_sets)
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
# Some duplication with common.py to avoid import cycle.
def _validate_max_staleness(max_staleness: Any) -> int:
"""Validate max_staleness."""
if max_staleness == -1:
return -1
if not isinstance(max_staleness, int):
raise TypeError(_invalid_max_staleness_msg(max_staleness))
if max_staleness <= 0:
raise ValueError(_invalid_max_staleness_msg(max_staleness))
return max_staleness
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
"""Validate hedge."""
if hedge is None:
return None
if not isinstance(hedge, dict):
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
return hedge
class _ServerMode:
"""Base class for all read preferences."""
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
def __init__(
self,
mode: int,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
self.__mongos_mode = _MONGOS_MODES[mode]
self.__mode = mode
self.__tag_sets = _validate_tag_sets(tag_sets)
self.__max_staleness = _validate_max_staleness(max_staleness)
self.__hedge = _validate_hedge(hedge)
@property
def name(self) -> str:
"""The name of this read preference."""
return self.__class__.__name__
@property
def mongos_mode(self) -> str:
"""The mongos mode of this read preference."""
return self.__mongos_mode
@property
def document(self) -> dict[str, Any]:
"""Read preference as a document."""
doc: dict[str, Any] = {"mode": self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc["tags"] = self.__tag_sets
if self.__max_staleness != -1:
doc["maxStalenessSeconds"] = self.__max_staleness
if self.__hedge not in (None, {}):
doc["hedge"] = self.__hedge
return doc
@property
def mode(self) -> int:
"""The mode of this read preference instance."""
return self.__mode
@property
def tag_sets(self) -> _TagSets:
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
read only from members whose ``dc`` tag has the value ``"ny"``.
To specify a priority-order for tag sets, provide a list of
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
set, ``{}``, means "read from any member that matches the mode,
ignoring tags." MongoClient tries each set of tags in turn
until it finds a set of tags with at least one matching member.
For example, to only send a query to an analytic node::
Nearest(tag_sets=[{"node":"analytics"}])
Or using :class:`SecondaryPreferred`::
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
.. seealso:: `Data-Center Awareness
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
"""
return list(self.__tag_sets) if self.__tag_sets else [{}]
@property
def max_staleness(self) -> int:
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum.
"""
return self.__max_staleness
@property
def hedge(self) -> Optional[_Hedge]:
"""The read preference ``hedge`` parameter.
A dictionary that configures how the server will perform hedged reads.
It consists of the following keys:
- ``enabled``: Enables or disables hedged reads in sharded clusters.
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
``nearest`` read preference. To explicitly enable hedged reads, set
the ``enabled`` key to ``true``::
>>> Nearest(hedge={'enabled': True})
To explicitly disable hedged reads, set the ``enabled`` key to
``False``::
>>> Nearest(hedge={'enabled': False})
.. versionadded:: 3.11
"""
return self.__hedge
@property
def min_wire_version(self) -> int:
"""The wire protocol version the server must support.
Some read preferences impose version requirements on all servers (e.g.
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
All servers' maxWireVersion must be at least this read preference's
`min_wire_version`, or the driver raises
:exc:`~pymongo.errors.ConfigurationError`.
"""
return 0 if self.__max_staleness == -1 else 5
def __repr__(self) -> str:
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
self.name,
self.__tag_sets,
self.__max_staleness,
self.__hedge,
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return (
self.mode == other.mode
and self.tag_sets == other.tag_sets
and self.max_staleness == other.max_staleness
and self.hedge == other.hedge
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __getstate__(self) -> dict[str, Any]:
"""Return value of object for pickling.
Needed explicitly because __slots__() defined.
"""
return {
"mode": self.__mode,
"tag_sets": self.__tag_sets,
"max_staleness": self.__max_staleness,
"hedge": self.__hedge,
}
def __setstate__(self, value: Mapping[str, Any]) -> None:
"""Restore from pickling."""
self.__mode = value["mode"]
self.__mongos_mode = _MONGOS_MODES[self.__mode]
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
self.__hedge = _validate_hedge(value["hedge"])
def __call__(self, selection: Selection) -> Selection:
return selection
class Primary(_ServerMode):
"""Primary read preference.
* When directly connected to one mongod queries are allowed if the server
is standalone or a replica set primary.
* When connected to a mongos queries are sent to the primary of a shard.
* When connected to a replica set queries are sent to the primary of
the replica set.
"""
__slots__ = ()
def __init__(self) -> None:
super().__init__(_PRIMARY)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return selection.primary_selection
def __repr__(self) -> str:
return "Primary()"
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return other.mode == _PRIMARY
return NotImplemented
class PrimaryPreferred(_ServerMode):
"""PrimaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are sent to the primary of a shard if
available, otherwise a shard secondary.
* When connected to a replica set queries are sent to the primary if
available, otherwise a secondary.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to an available secondary until the
primary of the replica set is discovered.
:param tag_sets: The :attr:`~tag_sets` to use if the primary is not
available.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` to use if the primary is not available.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
if selection.primary:
return selection.primary_selection
else:
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class Secondary(_ServerMode):
"""Secondary read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries. An error is raised if no secondaries are available.
* When connected to a replica set queries are distributed among
secondaries. An error is raised if no secondaries are available.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class SecondaryPreferred(_ServerMode):
"""SecondaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries, or the shard primary if no secondary is available.
* When connected to a replica set queries are distributed among
secondaries, or the primary if no secondary is available.
.. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
created reads will be routed to the primary of the replica set until
an available secondary is discovered.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
secondaries = secondary_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
if secondaries:
return secondaries
else:
return selection.primary_selection
class Nearest(_ServerMode):
"""Nearest read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among all members of
a shard.
* When connected to a replica set queries are distributed among all
members.
:param tag_sets: The :attr:`~tag_sets` for this read preference.
:param max_staleness: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
"""
__slots__ = ()
def __init__(
self,
tag_sets: Optional[_TagSets] = None,
max_staleness: int = -1,
hedge: Optional[_Hedge] = None,
) -> None:
super().__init__(_NEAREST, tag_sets, max_staleness, hedge)
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to Selection."""
return member_with_tags_server_selector(
self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection)
)
class _AggWritePref:
"""Agg $out/$merge write preference.
* If there are readable servers and there is any pre-5.0 server, use
primary read preference.
* Otherwise use `pref` read preference.
:param pref: The read preference to use on MongoDB 5.0+.
"""
__slots__ = ("pref", "effective_pref")
def __init__(self, pref: _ServerMode):
self.pref = pref
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
def selection_hook(self, topology_description: TopologyDescription) -> None:
common_wv = topology_description.common_wire_version
if (
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
and common_wv
and common_wv < 13
):
self.effective_pref = ReadPreference.PRIMARY
else:
self.effective_pref = self.pref
def __call__(self, selection: Selection) -> Selection:
"""Apply this read preference to a Selection."""
return self.effective_pref(selection)
def __repr__(self) -> str:
return f"_AggWritePref(pref={self.pref!r})"
# Proxy other calls to the effective_pref so that _AggWritePref can be
# used in place of an actual read preference.
def __getattr__(self, name: str) -> Any:
return getattr(self.effective_pref, name)
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
def make_read_preference(
mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
) -> _ServerMode:
if mode == _PRIMARY:
if tag_sets not in (None, [{}]):
raise ConfigurationError("Read preference primary cannot be combined with tags")
if max_staleness != -1:
raise ConfigurationError(
"Read preference primary cannot be combined with maxStalenessSeconds"
)
return Primary()
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
_MODES = (
"PRIMARY",
"PRIMARY_PREFERRED",
"SECONDARY",
"SECONDARY_PREFERRED",
"NEAREST",
)
class ReadPreference:
"""An enum that defines some commonly used read preference modes.
Apps can also create a custom read preference, for example::
Nearest(tag_sets=[{"node":"analytics"}])
See :doc:`/examples/high_availability` for code examples.
A read preference is used in three cases:
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
set primary.
- All other modes allow queries to standalone servers, to a replica set
primary, or to replica set secondaries.
:class:`~pymongo.mongo_client.MongoClient` initialized with the
``replicaSet`` option:
- ``PRIMARY``: Read from the primary. This is the default, and provides the
strongest consistency. If no primary is available, raise
:class:`~pymongo.errors.AutoReconnect`.
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
none, read from a secondary.
- ``SECONDARY``: Read from a secondary. If no secondary is available,
raise :class:`~pymongo.errors.AutoReconnect`.
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
from the primary.
- ``NEAREST``: Read from any member.
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
sharded cluster of replica sets:
- ``PRIMARY``: Read from the primary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
This is the default.
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
none, read from a secondary of the shard.
- ``SECONDARY``: Read from a secondary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
otherwise from the shard primary.
- ``NEAREST``: Read from any shard member.
"""
PRIMARY = Primary()
PRIMARY_PREFERRED = PrimaryPreferred()
SECONDARY = Secondary()
SECONDARY_PREFERRED = SecondaryPreferred()
NEAREST = Nearest()
def read_pref_mode_from_name(name: str) -> int:
"""Get the read preference mode from mongos/uri name."""
return _MONGOS_MODES.index(name)
class MovingAverage:
"""Tracks an exponentially-weighted moving average."""
average: Optional[float]
def __init__(self) -> None:
self.average = None
def add_sample(self, sample: float) -> None:
if sample < 0:
# Likely system time change while waiting for hello response
# and not using time.monotonic. Ignore it, the next one will
# probably be valid.
return
if self.average is None:
self.average = sample
else:
# The Server Selection Spec requires an exponentially weighted
# average with alpha = 0.2.
self.average = 0.8 * self.average + 0.2 * sample
def get(self) -> Optional[float]:
"""Get the calculated average, or None if no samples yet."""
return self.average
def reset(self) -> None:
self.average = None
__doc__ = original_doc

View File

@ -1,4 +1,4 @@
# Copyright 2014-present MongoDB, Inc.
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,288 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent one server the driver is connected to."""
"""Re-import of synchronous ServerDescription API for compatibility."""
from __future__ import annotations
import time
import warnings
from typing import Any, Mapping, Optional
from pymongo.synchronous.server_description import * # noqa: F403
from pymongo.synchronous.server_description import __doc__ as original_doc
from bson import EPOCH_NAIVE
from bson.objectid import ObjectId
from pymongo.hello import Hello
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import ClusterTime, _Address
class ServerDescription:
"""Immutable representation of one server.
:param address: A (host, port) pair
:param hello: Optional Hello instance
:param round_trip_time: Optional float
:param error: Optional, the last error attempting to connect to the server
:param round_trip_time: Optional float, the min latency from the most recent samples
"""
__slots__ = (
"_address",
"_server_type",
"_all_hosts",
"_tags",
"_replica_set_name",
"_primary",
"_max_bson_size",
"_max_message_size",
"_max_write_batch_size",
"_min_wire_version",
"_max_wire_version",
"_round_trip_time",
"_min_round_trip_time",
"_me",
"_is_writable",
"_is_readable",
"_ls_timeout_minutes",
"_error",
"_set_version",
"_election_id",
"_cluster_time",
"_last_write_date",
"_last_update_time",
"_topology_version",
)
def __init__(
self,
address: _Address,
hello: Optional[Hello] = None,
round_trip_time: Optional[float] = None,
error: Optional[Exception] = None,
min_round_trip_time: float = 0.0,
) -> None:
self._address = address
if not hello:
hello = Hello({})
self._server_type = hello.server_type
self._all_hosts = hello.all_hosts
self._tags = hello.tags
self._replica_set_name = hello.replica_set_name
self._primary = hello.primary
self._max_bson_size = hello.max_bson_size
self._max_message_size = hello.max_message_size
self._max_write_batch_size = hello.max_write_batch_size
self._min_wire_version = hello.min_wire_version
self._max_wire_version = hello.max_wire_version
self._set_version = hello.set_version
self._election_id = hello.election_id
self._cluster_time = hello.cluster_time
self._is_writable = hello.is_writable
self._is_readable = hello.is_readable
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
self._round_trip_time = round_trip_time
self._min_round_trip_time = min_round_trip_time
self._me = hello.me
self._last_update_time = time.monotonic()
self._error = error
self._topology_version = hello.topology_version
if error:
details = getattr(error, "details", None)
if isinstance(details, dict):
self._topology_version = details.get("topologyVersion")
self._last_write_date: Optional[float]
if hello.last_write_date:
# Convert from datetime to seconds.
delta = hello.last_write_date - EPOCH_NAIVE
self._last_write_date = delta.total_seconds()
else:
self._last_write_date = None
@property
def address(self) -> _Address:
"""The address (host, port) of this server."""
return self._address
@property
def server_type(self) -> int:
"""The type of this server."""
return self._server_type
@property
def server_type_name(self) -> str:
"""The server type as a human readable string.
.. versionadded:: 3.4
"""
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@property
def tags(self) -> Mapping[str, Any]:
return self._tags
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._replica_set_name
@property
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
return self._primary
@property
def max_bson_size(self) -> int:
return self._max_bson_size
@property
def max_message_size(self) -> int:
return self._max_message_size
@property
def max_write_batch_size(self) -> int:
return self._max_write_batch_size
@property
def min_wire_version(self) -> int:
return self._min_wire_version
@property
def max_wire_version(self) -> int:
return self._max_wire_version
@property
def set_version(self) -> Optional[int]:
return self._set_version
@property
def election_id(self) -> Optional[ObjectId]:
return self._election_id
@property
def cluster_time(self) -> Optional[ClusterTime]:
return self._cluster_time
@property
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
warnings.warn(
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
DeprecationWarning,
stacklevel=2,
)
return self._set_version, self._election_id
@property
def me(self) -> Optional[tuple[str, int]]:
return self._me
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._ls_timeout_minutes
@property
def last_write_date(self) -> Optional[float]:
return self._last_write_date
@property
def last_update_time(self) -> float:
return self._last_update_time
@property
def round_trip_time(self) -> Optional[float]:
"""The current average latency or None."""
# This override is for unittesting only!
if self._address in self._host_to_round_trip_time:
return self._host_to_round_trip_time[self._address]
return self._round_trip_time
@property
def min_round_trip_time(self) -> float:
"""The min latency from the most recent samples."""
return self._min_round_trip_time
@property
def error(self) -> Optional[Exception]:
"""The last error attempting to connect to the server, or None."""
return self._error
@property
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self) -> bool:
return self._is_readable
@property
def mongos(self) -> bool:
return self._server_type == SERVER_TYPE.Mongos
@property
def is_server_type_known(self) -> bool:
return self.server_type != SERVER_TYPE.Unknown
@property
def retryable_writes_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return (
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
) or self._server_type == SERVER_TYPE.LoadBalancer
@property
def retryable_reads_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return self._max_wire_version >= 6
@property
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._topology_version
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
unknown = ServerDescription(self.address, error=error)
unknown._topology_version = self.topology_version
return unknown
def __eq__(self, other: Any) -> bool:
if isinstance(other, ServerDescription):
return (
(self._address == other.address)
and (self._server_type == other.server_type)
and (self._min_wire_version == other.min_wire_version)
and (self._max_wire_version == other.max_wire_version)
and (self._me == other.me)
and (self._all_hosts == other.all_hosts)
and (self._tags == other.tags)
and (self._replica_set_name == other.replica_set_name)
and (self._set_version == other.set_version)
and (self._election_id == other.election_id)
and (self._primary == other.primary)
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
and (self._error == other.error)
)
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self) -> str:
errmsg = ""
if self.error:
errmsg = f", error={self.error!r}"
return "<{} {} server_type: {}, rtt: {}{}>".format(
self.__class__.__name__,
self.address,
self.server_type_name,
self.round_trip_time,
errmsg,
)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time: dict = {}
__doc__ = original_doc

View File

View File

@ -18,20 +18,22 @@ from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref
from pymongo.synchronous import common
from pymongo.synchronous.collation import validate_collation_or_none
from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from pymongo.database import Database
from pymongo.pool import Connection
from pymongo.read_preferences import _ServerMode
from pymongo.server import Server
from pymongo.typings import _DocumentType, _Pipeline
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.read_preferences import _ServerMode
from pymongo.synchronous.server import Server
from pymongo.synchronous.typings import _DocumentType, _Pipeline
_IS_SYNC = True
class _AggregationCommand:

658
pymongo/synchronous/auth.py Normal file
View File

@ -0,0 +1,658 @@
# Copyright 2013-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Authentication helpers."""
from __future__ import annotations
import functools
import hashlib
import hmac
import os
import socket
import typing
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
MutableMapping,
Optional,
cast,
)
from urllib.parse import quote
from bson.binary import Binary
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
from pymongo.synchronous.auth_aws import _authenticate_aws
from pymongo.synchronous.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
_OIDCAzureCallback,
_OIDCGCPCallback,
_OIDCProperties,
_OIDCTestCallback,
)
if TYPE_CHECKING:
from pymongo.synchronous.hello import Hello
from pymongo.synchronous.pool import Connection
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
import winkerberos as kerberos # type:ignore[import]
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
_USE_PRINCIPAL = True
except ImportError:
try:
import kerberos # type:ignore[import]
except ImportError:
HAVE_KERBEROS = False
_IS_SYNC = True
MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DEFAULT",
]
)
"""The authentication mechanisms supported by PyMongo."""
class _Cache:
__slots__ = ("data",)
_hash_val = hash("_Cache")
def __init__(self) -> None:
self.data = None
def __eq__(self, other: object) -> bool:
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other: object) -> bool:
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self) -> int:
return self._hash_val
MongoCredential = namedtuple(
"MongoCredential",
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
)
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple(
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
)
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
"""Mechanism properties for MONGODB-AWS authentication."""
def _build_credentials_tuple(
mech: str,
source: Optional[str],
user: str,
passwd: str,
extra: Mapping[str, Any],
database: Optional[str],
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
properties = extra.get("authmechanismproperties", {})
service_name = properties.get("SERVICE_NAME", "mongodb")
canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
service_realm = properties.get("SERVICE_REALM")
props = GSSAPIProperties(
service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm,
)
# Source is always $external.
return MongoCredential(mech, "$external", user, passwd, props, None)
elif mech == "MONGODB-X509":
if passwd is not None:
raise ConfigurationError("Passwords are not supported by MONGODB-X509")
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for MONGODB-X509")
# Source is always $external, user can be None.
return MongoCredential(mech, "$external", user, None, None, None)
elif mech == "MONGODB-AWS":
if user is not None and passwd is None:
raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
if source is not None and source != "$external":
raise ConfigurationError(
"authentication source must be $external or None for MONGODB-AWS"
)
properties = extra.get("authmechanismproperties", {})
aws_session_token = properties.get("AWS_SESSION_TOKEN")
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
callback = properties.get("OIDC_CALLBACK")
human_callback = properties.get("OIDC_HUMAN_CALLBACK")
environ = properties.get("ENVIRONMENT")
token_resource = properties.get("TOKEN_RESOURCE", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
msg = (
"authentication with MONGODB-OIDC requires providing either a callback or a environment"
)
if passwd is not None:
msg = "password is not supported by MONGODB-OIDC"
raise ConfigurationError(msg)
if callback or human_callback:
if environ is not None:
raise ConfigurationError(msg)
if callback and human_callback:
msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
raise ConfigurationError(msg)
elif environ is not None:
if environ == "test":
if user is not None:
msg = "test environment for MONGODB-OIDC does not support username"
raise ConfigurationError(msg)
callback = _OIDCTestCallback()
elif environ == "azure":
passwd = None
if not token_resource:
raise ConfigurationError(
"Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCAzureCallback(token_resource)
elif environ == "gcp":
passwd = None
if not token_resource:
raise ConfigurationError(
"GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
)
callback = _OIDCGCPCallback(token_resource)
else:
raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
else:
raise ConfigurationError(msg)
oidc_props = _OIDCProperties(
callback=callback,
human_callback=human_callback,
environment=environ,
allowed_hosts=allowed_hosts,
token_resource=token_resource,
username=user,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
def _xor(fir: bytes, sec: bytes) -> bytes:
"""XOR two byte strings together."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
"""Split a scram response into key, value pairs."""
return dict(
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
for item in response.split(b",")
)
def _authenticate_scram_start(
credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
username = credentials.username
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = {
"saslStart": 1,
"mechanism": mechanism,
"payload": Binary(b"n,," + first_bare),
"autoAuthorize": 1,
"options": {"skipEmptyExchange": True},
}
return nonce, first_bare, cmd
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == "SCRAM-SHA-256":
digest = "sha256"
digestmod = hashlib.sha256
data = saslprep(credentials.password).encode("utf-8")
else:
digest = "sha1"
digestmod = hashlib.sha1
data = _password_digest(username, credentials.password).encode("utf-8")
source = credentials.source
cache = credentials.cache
# Make local
_hmac = hmac.HMAC
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
assert isinstance(ctx, _ScramContext)
assert ctx.scram_data is not None
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = conn.command(source, cmd)
assert res is not None
server_first = res["payload"]
parsed = _parse_scram_response(server_first)
iterations = int(parsed[b"i"])
if iterations < 4096:
raise OperationFailure("Server returned an invalid iteration count.")
salt = parsed[b"s"]
rnonce = parsed[b"r"]
if not rnonce.startswith(nonce):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
if cache.data:
client_key, server_key, csalt, citerations = cache.data
else:
client_key, server_key, csalt, citerations = None, None, None, None
# Salt and / or iterations could change for a number of different
# reasons. Either changing invalidates the cache.
if not client_key or salt != csalt or iterations != citerations:
salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key, salt, iterations)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
client_final = b",".join((without_proof, client_proof))
server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(client_final),
}
res = conn.command(source, cmd)
parsed = _parse_scram_response(res["payload"])
if not hmac.compare_digest(parsed[b"v"], server_sig):
raise OperationFailure("Server returned an invalid signature.")
# A third empty challenge may be required if the server does not support
# skipEmptyExchange: SERVER-44857.
if not res["done"]:
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(b""),
}
res = conn.command(source, cmd)
if not res["done"]:
raise OperationFailure("SASL conversation failed to complete.")
def _password_digest(username: str, password: str) -> str:
"""Get a password digest to use for authentication."""
if not isinstance(password, str):
raise TypeError("password must be an instance of str")
if len(password) == 0:
raise ValueError("password can't be empty")
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _auth_key(nonce: str, username: str, password: str) -> str:
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5() # noqa: S324
data = f"{nonce}{username}{digest}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _canonicalize_hostname(hostname: str) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
try:
name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
except socket.gaierror:
return canonname.lower()
return name[0].lower()
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
'The "kerberos" module must be installed to use GSSAPI authentication.'
)
try:
username = credentials.username
password = credentials.password
props = credentials.mechanism_properties
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = conn.address[0]
if props.canonicalize_host_name:
host = _canonicalize_hostname(host)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
if password is not None:
if _USE_PRINCIPAL:
# Note that, though we use unquote_plus for unquoting URI
# options, we use quote here. Microsoft's UrlUnescape (used
# by WinKerberos) doesn't support +.
principal = ":".join((quote(username), quote(password)))
result, ctx = kerberos.authGSSClientInit(
service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
)
else:
if "@" in username:
user, domain = username.split("@", 1)
else:
user, domain = username, None
result, ctx = kerberos.authGSSClientInit(
service,
gssflags=kerberos.GSS_C_MUTUAL_FLAG,
user=user,
domain=domain,
password=password,
)
else:
result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
if result != kerberos.AUTH_GSS_COMPLETE:
raise OperationFailure("Kerberos context failed to initialize.")
try:
# pykerberos uses a weird mix of exceptions and return values
# to indicate errors.
# 0 == continue, 1 == complete, -1 == error
# Only authGSSClientStep can return 0.
if kerberos.authGSSClientStep(ctx, "") != 0:
raise OperationFailure("Unknown kerberos failure in step function.")
# Start a SASL conversation with mongod/s
# Note: pykerberos deals with base64 encoded byte strings.
# Since mongo accepts base64 strings as the payload we don't
# have to use bson.binary.Binary.
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslStart": 1,
"mechanism": "GSSAPI",
"payload": payload,
"autoAuthorize": 1,
}
response = conn.command("$external", cmd)
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
if result == -1:
raise OperationFailure("Unknown kerberos failure in step function.")
payload = kerberos.authGSSClientResponse(ctx) or ""
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
response = conn.command("$external", cmd)
if result == kerberos.AUTH_GSS_COMPLETE:
break
else:
raise OperationFailure("Kerberos authentication failed to complete.")
# Once the security context is established actually authenticate.
# See RFC 4752, Section 3.1, last two paragraphs.
if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")
if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
conn.command("$external", cmd)
finally:
kerberos.authGSSClientClean(ctx)
except kerberos.KrbError as exc:
raise OperationFailure(str(exc)) from None
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
password = credentials.password
payload = (f"\x00{username}\x00{password}").encode()
cmd = {
"saslStart": 1,
"mechanism": "PLAIN",
"payload": Binary(payload),
"autoAuthorize": 1,
}
conn.command(source, cmd)
def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-X509."""
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step.
return
cmd = _X509Context(credentials, conn.address).speculate_command()
conn.command("$external", cmd)
def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
password = credentials.password
# Get a nonce
response = conn.command(source, {"getnonce": 1})
nonce = response["nonce"]
key = _auth_key(nonce, username, password)
# Actually authenticate
query = {"authenticate": 1, "user": username, "nonce": nonce, "key": key}
conn.command(source, query)
def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
if conn.max_wire_version >= 7:
if conn.negotiated_mechs:
mechs = conn.negotiated_mechs
else:
source = credentials.source
cmd = conn.hello_cmd()
cmd["saslSupportedMechs"] = source + "." + credentials.username
mechs = (conn.command(source, cmd, publish_events=False)).get("saslSupportedMechs", [])
if "SCRAM-SHA-256" in mechs:
return _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
else:
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
else:
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
_AUTH_MAP: Mapping[str, Callable[..., None]] = {
"GSSAPI": _authenticate_gssapi,
"MONGODB-CR": _authenticate_mongo_cr,
"MONGODB-X509": _authenticate_x509,
"MONGODB-AWS": _authenticate_aws,
"MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item]
"PLAIN": _authenticate_plain,
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
"DEFAULT": _authenticate_default,
}
class _AuthContext:
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
self.credentials = credentials
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
self.address = address
@staticmethod
def from_credentials(
creds: MongoCredential, address: tuple[str, int]
) -> Optional[_AuthContext]:
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
if spec_cls:
return cast(_AuthContext, spec_cls(creds, address))
return None
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
raise NotImplementedError
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
self.speculative_authenticate = hello.speculative_authenticate
def speculate_succeeded(self) -> bool:
return bool(self.speculative_authenticate)
class _ScramContext(_AuthContext):
def __init__(
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
) -> None:
super().__init__(credentials, address)
self.scram_data: Optional[tuple[bytes, bytes]] = None
self.mechanism = mechanism
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
# The 'db' field is included only on the speculative command.
cmd["db"] = self.credentials.source
# Save for later use.
self.scram_data = (nonce, first_bare)
return cmd
class _X509Context(_AuthContext):
def speculate_command(self) -> MutableMapping[str, Any]:
cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
if self.credentials.username is not None:
cmd["user"] = self.credentials.username
return cmd
class _OIDCContext(_AuthContext):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.get_spec_auth_cmd()
if cmd is None:
return None
cmd["db"] = self.credentials.source
return cmd
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
"MONGODB-X509": _X509Context,
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
"MONGODB-OIDC": _OIDCContext,
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}
def authenticate(
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
"""Authenticate connection."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]
if mechanism == "MONGODB-OIDC":
_authenticate_oidc(credentials, conn, reauthenticate)
else:
auth_func(credentials, conn)

View File

@ -23,8 +23,10 @@ from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING:
from bson.typings import _ReadableBuffer
from pymongo.auth import MongoCredential
from pymongo.pool import Connection
from pymongo.synchronous.auth import MongoCredential
from pymongo.synchronous.pool import Connection
_IS_SYNC = True
def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
@ -36,7 +38,6 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
"MONGODB-AWS authentication requires pymongo-auth-aws: "
"install with: python -m pip install 'pymongo[aws]'"
) from e
# Delayed import.
from pymongo_auth_aws.auth import ( # type:ignore[import]
set_cached_credentials,

View File

@ -0,0 +1,378 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import abc
import os
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
from urllib.parse import quote
import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.synchronous.auth import MongoCredential
from pymongo.synchronous.pool import Connection
_IS_SYNC = True
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)
@dataclass
class OIDCCallbackContext:
timeout_seconds: float
username: str
version: int
refresh_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
@dataclass
class OIDCCallbackResult:
access_token: str
expires_in_seconds: Optional[float] = field(default=None)
refresh_token: Optional[str] = field(default=None)
class OIDCCallback(abc.ABC):
"""A base class for defining OIDC callbacks."""
@abc.abstractmethod
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
"""Convert the given BSON value into our own type."""
@dataclass
class _OIDCProperties:
callback: Optional[OIDCCallback] = field(default=None)
human_callback: Optional[OIDCCallback] = field(default=None)
environment: Optional[str] = field(default=None)
allowed_hosts: list[str] = field(default_factory=list)
token_resource: Optional[str] = field(default=None)
username: str = ""
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
def _get_authenticator(
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
if credentials.cache.data:
return credentials.cache.data
# Extract values.
principal_name = credentials.username
properties = credentials.mechanism_properties
# Validate that the address is allowed.
if not properties.environment:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:
if patt == address[0]:
found = True
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
found = True
if not found:
raise ConfigurationError(
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
)
# Get or create the cache data.
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
return credentials.cache.data
class _OIDCTestCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("OIDC_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAWSCallback(OIDCCallback):
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
if not token_file:
raise RuntimeError(
'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
)
with open(token_file) as fid:
return OIDCCallbackResult(access_token=fid.read().strip())
class _OIDCAzureCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
return OIDCCallbackResult(
access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
)
class _OIDCGCPCallback(OIDCCallback):
def __init__(self, token_resource: str) -> None:
self.token_resource = quote(token_resource)
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
return OIDCCallbackResult(access_token=resp["access_token"])
@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
refresh_token: Optional[str] = field(default=None)
access_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)
last_call_time: float = field(default=0)
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle a reauthenticate from the server."""
# Invalidate the token for the connection.
self._invalidate(conn)
# Call the appropriate auth logic for the callback type.
if self.properties.callback:
return self._authenticate_machine(conn)
return self._authenticate_human(conn)
def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle an initial authenticate request."""
# First handle speculative auth.
# If it succeeded, we are done.
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
resp = ctx.speculative_authenticate
if resp and resp["done"]:
conn.oidc_token_gen_id = self.token_gen_id
return resp
# If spec auth failed, call the appropriate auth logic for the callback type.
# We cannot assume that the token is invalid, because a proxy may have been
# involved that stripped the speculative auth information.
if self.properties.callback:
return self._authenticate_machine(conn)
return self._authenticate_human(conn)
def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
"""Get the appropriate speculative auth command."""
if not self.access_token:
return None
return self._get_start_command({"jwt": self.access_token})
def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
# fails for any other reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_machine(conn)
raise
return self._sasl_start_jwt(conn)
def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
# reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_human(conn)
raise
# If we have a cached refresh token, try a JwtStepRequest with that.
# If authentication fails with error code 18, invalidate the access and
# refresh tokens, and try to authenticate again. If authentication fails for
# any other reason, raise the error to the user.
if self.refresh_token:
try:
return self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
self.refresh_token = None
return self._authenticate_human(conn)
raise
# Start a new Two-Step SASL conversation.
# Run a PrincipalStepRequest to get the IdpInfo.
cmd = self._get_start_command(None)
start_resp = self._run_command(conn, cmd)
# Attempt to authenticate with a JwtStepRequest.
return self._sasl_continue_jwt(conn, start_resp)
def _get_access_token(self) -> Optional[str]:
properties = self.properties
cb: Union[None, OIDCCallback]
resp: OIDCCallbackResult
is_human = properties.human_callback is not None
if is_human and self.idp_info is None:
return None
if properties.callback:
cb = properties.callback
if properties.human_callback:
cb = properties.human_callback
prev_token = self.access_token
if prev_token:
return prev_token
if cb is None and not prev_token:
return None
if not prev_token and cb is not None:
with self.lock:
# See if the token was changed while we were waiting for the
# lock.
new_token = self.access_token
if new_token != prev_token:
return new_token
# Ensure that we are waiting a min time between callback invocations.
delta = time.time() - self.last_call_time
if delta < TIME_BETWEEN_CALLS_SECONDS:
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
self.last_call_time = time.time()
if is_human:
timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
assert self.idp_info is not None
else:
timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
context = OIDCCallbackContext(
timeout_seconds=timeout,
version=CALLBACK_VERSION,
refresh_token=self.refresh_token,
idp_info=self.idp_info,
username=self.properties.username,
)
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError("Callback result must be of type OIDCCallbackResult")
self.refresh_token = resp.refresh_token
self.access_token = resp.access_token
self.token_gen_id += 1
return self.access_token
def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
try:
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as e:
if self._is_auth_error(e):
self._invalidate(conn)
raise
def _is_auth_error(self, err: Exception) -> bool:
if not isinstance(err, OperationFailure):
return False
return err.code == _AUTHENTICATION_FAILURE_CODE
def _invalidate(self, conn: Connection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.
token_gen_id = conn.oidc_token_gen_id or 0
if token_gen_id is not None and token_gen_id < self.token_gen_id:
return
self.access_token = None
def _sasl_continue_jwt(
self, conn: Connection, start_resp: Mapping[str, Any]
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
start_payload: dict = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
return self._run_command(conn, cmd)
def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_start_command({"jwt": access_token})
return self._run_command(conn, cmd)
def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
if payload is None:
principal_name = self.username
if principal_name:
payload = {"n": principal_name}
else:
payload = {}
bin_payload = Binary(bson.encode(payload))
return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}
def _get_continue_command(
self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
) -> MutableMapping[str, Any]:
bin_payload = Binary(bson.encode(payload))
return {
"saslContinue": 1,
"payload": bin_payload,
"conversationId": start_resp["conversationId"],
}
def _authenticate_oidc(
credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, conn.address)
if reauthenticate:
return authenticator.reauthenticate(conn)
else:
return authenticator.authenticate(conn)

View File

@ -34,21 +34,23 @@ from typing import (
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.client_session import ClientSession, _validate_session_write_concern
from pymongo.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo import _csot
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure,
)
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
from pymongo.message import (
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
from pymongo.synchronous import common
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo.synchronous.helpers import _get_wce_doc
from pymongo.synchronous.message import (
_DELETE,
_INSERT,
_UPDATE,
@ -56,13 +58,15 @@ from pymongo.message import (
_EncryptedBulkWriteContext,
_randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongo.collection import Collection
from pymongo.pool import Connection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.typings import _DocumentOut, _DocumentType, _Pipeline
_IS_SYNC = True
_DELETE_ALL: int = 0
_DELETE_ONE: int = 1
@ -449,7 +453,7 @@ class _Bulk:
)
client = self.collection.database.client
client._retryable_write(
_ = client._retryable_write(
self.is_retryable,
retryable_bulk,
session,

View File

@ -0,0 +1,497 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Watch changes on a collection, a database, or the entire cluster."""
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.errors import (
ConnectionFailure,
CursorNotFound,
InvalidOperation,
OperationFailure,
PyMongoError,
)
from pymongo.synchronous import common
from pymongo.synchronous.aggregation import (
_AggregationCommand,
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
)
from pymongo.synchronous.collation import validate_collation_or_none
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.operations import _Op
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline
_IS_SYNC = True
# The change streams spec considers the following server errors from the
# getMore command non-resumable. All other getMore errors are resumable.
_RESUMABLE_GETMORE_ERRORS = frozenset(
[
6, # HostUnreachable
7, # HostNotFound
89, # NetworkTimeout
91, # ShutdownInProgress
189, # PrimarySteppedDown
262, # ExceededTimeLimit
9001, # SocketException
10107, # NotWritablePrimary
11600, # InterruptedAtShutdown
11602, # InterruptedDueToReplStateChange
13435, # NotPrimaryNoSecondaryOk
13436, # NotPrimaryOrSecondary
63, # StaleShardVersion
150, # StaleEpoch
13388, # StaleConfig
234, # RetryChangeStream
133, # FailedToSatisfyReadPreference
]
)
if TYPE_CHECKING:
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import Connection
def _resumable(exc: PyMongoError) -> bool:
"""Return True if given a resumable change stream error."""
if isinstance(exc, (ConnectionFailure, CursorNotFound)):
return True
if isinstance(exc, OperationFailure):
if exc._max_wire_version is None:
return False
return (
exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
return False
class ChangeStream(Generic[_DocumentType]):
"""The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.collection.Collection.watch`,
:meth:`pymongo.database.Database.watch`, or
:meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded:: 3.6
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
"""
def __init__(
self,
target: Union[
MongoClient[_DocumentType],
Database[_DocumentType],
Collection[_DocumentType],
],
pipeline: Optional[_Pipeline],
full_document: Optional[str],
resume_after: Optional[Mapping[str, Any]],
max_await_time_ms: Optional[int],
batch_size: Optional[int],
collation: Optional[_CollationIn],
start_at_operation_time: Optional[Timestamp],
session: Optional[ClientSession],
start_after: Optional[Mapping[str, Any]],
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
show_expanded_events: Optional[bool] = None,
) -> None:
if pipeline is None:
pipeline = []
pipeline = common.validate_list("pipeline", pipeline)
common.validate_string_or_none("full_document", full_document)
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._decode_custom = False
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options( # type: ignore
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
)
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document
self._full_document_before_change = full_document_before_change
self._uses_start_after = start_after is not None
self._uses_resume_after = resume_after is not None
self._resume_token = copy.deepcopy(start_after or resume_after)
self._max_await_time_ms = max_await_time_ms
self._batch_size = batch_size
self._collation = collation
self._start_at_operation_time = start_at_operation_time
self._session = session
self._comment = comment
self._closed = False
self._timeout = self._target._timeout
self._show_expanded_events = show_expanded_events
def _initialize_cursor(self) -> None:
# Initialize cursor.
self._cursor = self._create_cursor()
@property
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
"""The aggregation command class to be used."""
raise NotImplementedError
@property
def _client(self) -> MongoClient:
"""The client against which the aggregation commands for
this ChangeStream will be run.
"""
raise NotImplementedError
def _change_stream_options(self) -> dict[str, Any]:
"""Return the options dict for the $changeStream pipeline stage."""
options: dict[str, Any] = {}
if self._full_document is not None:
options["fullDocument"] = self._full_document
if self._full_document_before_change is not None:
options["fullDocumentBeforeChange"] = self._full_document_before_change
resume_token = self.resume_token
if resume_token is not None:
if self._uses_start_after:
options["startAfter"] = resume_token
else:
options["resumeAfter"] = resume_token
elif self._start_at_operation_time is not None:
options["startAtOperationTime"] = self._start_at_operation_time
if self._show_expanded_events:
options["showExpandedEvents"] = self._show_expanded_events
return options
def _command_options(self) -> dict[str, Any]:
"""Return the options dict for the aggregation command."""
options = {}
if self._max_await_time_ms is not None:
options["maxAwaitTimeMS"] = self._max_await_time_ms
if self._batch_size is not None:
options["batchSize"] = self._batch_size
return options
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
"""Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents.
This is implemented as a callback because we need access to the wire
version in order to determine whether to cache this value.
"""
if not result["cursor"]["firstBatch"]:
if "postBatchResumeToken" in result["cursor"]:
self._resume_token = result["cursor"]["postBatchResumeToken"]
elif (
self._start_at_operation_time is None
and self._uses_resume_after is False
and self._uses_start_after is False
and conn.max_wire_version >= 7
):
self._start_at_operation_time = result.get("operationTime")
# PYTHON-2181: informative error on missing operationTime.
if self._start_at_operation_time is None:
raise OperationFailure(
"Expected field 'operationTime' missing from command "
f"response : {result!r}"
)
def _run_aggregation_cmd(
self, session: Optional[ClientSession], explicit_session: bool
) -> CommandCursor:
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding CommandCursor.
"""
cmd = self._aggregation_command_class(
self._target,
CommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
return self._client._retryable_read(
cmd.get_cursor,
self._target._read_preference_for(session),
session,
operation=_Op.AGGREGATE,
)
def _create_cursor(self) -> CommandCursor:
with self._client._tmp_session(self._session, close=False) as s:
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""
try:
self._cursor.close()
except PyMongoError:
pass
self._cursor = self._create_cursor()
def close(self) -> None:
"""Close this ChangeStream."""
self._closed = True
self._cursor.close()
def __iter__(self) -> ChangeStream[_DocumentType]:
return self
@property
def resume_token(self) -> Optional[Mapping[str, Any]]:
"""The cached resume token that will be used to resume after the most
recently returned change.
.. versionadded:: 3.9
"""
return copy.deepcopy(self._resume_token)
@_csot.apply
def next(self) -> _DocumentType:
"""Advance the cursor.
This method blocks until the next change document is returned or an
unrecoverable error is raised. This method is used when iterating over
all changes in the cursor. For example::
try:
resume_token = None
pipeline = [{'$match': {'operationType': 'insert'}}]
async with db.collection.watch(pipeline) as stream:
async for insert_change in stream:
print(insert_change)
resume_token = stream.resume_token
except pymongo.errors.PyMongoError:
# The ChangeStream encountered an unrecoverable error or the
# resume attempt failed to recreate the cursor.
if resume_token is None:
# There is no usable resume token because there was a
# failure during ChangeStream initialization.
logging.error('...')
else:
# Use the interrupted ChangeStream's resume token to create
# a new ChangeStream. The new stream will continue from the
# last seen insert change without missing any events.
async with db.collection.watch(
pipeline, resume_after=resume_token) as stream:
async for insert_change in stream:
print(insert_change)
Raises :exc:`StopIteration` if this ChangeStream is closed.
"""
while self.alive:
doc = self.try_next()
if doc is not None:
return doc
raise StopIteration
__next__ = next
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
.. versionadded:: 3.8
"""
return not self._closed
@_csot.apply
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
indefinitely for the next change. For example::
async with db.collection.watch() as stream:
while stream.alive:
change = await stream.try_next()
# Note that the ChangeStream's resume token may be updated
# even when no changes are returned.
print("Current resume token: %r" % (stream.resume_token,))
if change is not None:
print("Change document: %r" % (change,))
continue
# We end up here when there are no recent changes.
# Sleep for a while before trying again to avoid flooding
# the server with getMore requests when no changes are
# available.
time.sleep(10)
If no change document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there have been no changes) then ``None`` is returned.
:return: The next change document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 3.8
"""
if not self._closed and not self._cursor.alive:
self._resume()
# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
try:
change = self._cursor._try_next(True)
except PyMongoError as exc:
if not _resumable(exc):
raise
self._resume()
change = self._cursor._try_next(False)
except PyMongoError as exc:
# Close the stream after a fatal error.
if not _resumable(exc) and not exc.timeout:
self.close()
raise
except Exception:
self.close()
raise
# Check if the cursor was invalidated.
if not self._cursor.alive:
self._closed = True
# If no changes are available.
if change is None:
# We have either iterated over all documents in the cursor,
# OR the most-recently returned batch is empty. In either case,
# update the cached resume token with the postBatchResumeToken if
# one was returned. We also clear the startAtOperationTime.
if self._cursor._post_batch_resume_token is not None:
self._resume_token = self._cursor._post_batch_resume_token
self._start_at_operation_time = None
return change
# Else, changes are available.
try:
resume_token = change["_id"]
except KeyError:
self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume token is missing."
) from None
# If this is the last change document from the current batch, cache the
# postBatchResumeToken.
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
resume_token = self._cursor._post_batch_resume_token
# Hereafter, don't use startAfter; instead use resumeAfter.
self._uses_start_after = False
self._uses_resume_after = True
# Cache the resume token and clear startAtOperationTime.
self._resume_token = resume_token
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
def __enter__(self) -> ChangeStream[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
class CollectionChangeStream(ChangeStream[_DocumentType]):
"""A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.collection.Collection.watch` instead.
.. versionadded:: 3.7
"""
_target: Collection[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
return _CollectionAggregationCommand
@property
def _client(self) -> MongoClient[_DocumentType]:
return self._target.database.client
class DatabaseChangeStream(ChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.database.Database.watch` instead.
.. versionadded:: 3.7
"""
_target: Database[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
return _DatabaseAggregationCommand
@property
def _client(self) -> MongoClient[_DocumentType]:
return self._target.client
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded:: 3.7
"""
def _change_stream_options(self) -> dict[str, Any]:
options = super()._change_stream_options()
options["allChangesForCluster"] = True
return options

View File

@ -0,0 +1,334 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
from bson.codec_options import _parse_codec_options
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.ssl_support import get_ssl_context
from pymongo.synchronous import common
from pymongo.synchronous.compression_support import CompressionSettings
from pymongo.synchronous.monitoring import _EventListener, _EventListeners
from pymongo.synchronous.pool import PoolOptions
from pymongo.synchronous.read_preferences import (
_ServerMode,
make_read_preference,
read_pref_mode_from_name,
)
from pymongo.synchronous.server_selectors import any_server_selector
from pymongo.write_concern import WriteConcern, validate_boolean
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.pyopenssl_context import SSLContext
from pymongo.synchronous.auth import MongoCredential
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
from pymongo.synchronous.topology_description import _ServerSelector
_IS_SYNC = True
def _parse_credentials(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
"""Parse authentication credentials."""
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
source = options.get("authsource")
if username or mechanism:
from pymongo.synchronous.auth import _build_credentials_tuple
return _build_credentials_tuple(mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
"""Parse read preference options."""
if "read_preference" in options:
return options["read_preference"]
name = options.get("readpreference", "primary")
mode = read_pref_mode_from_name(name)
tags = options.get("readpreferencetags")
max_staleness = options.get("maxstalenessseconds", -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
"""Parse write concern options."""
concern = options.get("w")
wtimeout = options.get("wtimeoutms")
j = options.get("journal")
fsync = options.get("fsync")
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
"""Parse read concern options."""
concern = options.get("readconcernlevel")
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
validate_boolean("tls", use_tls)
certfile = options.get("tlscertificatekeyfile")
passphrase = options.get("tlscertificatekeyfilepassword")
ca_certs = options.get("tlscafile")
crlfile = options.get("tlscrlfile")
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
enabled_tls_opts = []
for opt in (
"tlscertificatekeyfile",
"tlscertificatekeyfilepassword",
"tlscafile",
"tlscrlfile",
):
# Any non-null value of these options implies tls=True.
if opt in options and options[opt]:
enabled_tls_opts.append(opt)
for opt in (
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
):
# A value of False for these options implies tls=True.
if opt in options and not options[opt]:
enabled_tls_opts.append(opt)
if enabled_tls_opts:
if use_tls is None:
# Implicitly enable TLS when one of the tls* options is set.
use_tls = True
elif not use_tls:
# Error since tls is explicitly disabled but a tls option is set.
raise ConfigurationError(
"TLS has not been enabled but the "
"following tls parameters have been set: "
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
)
if use_tls:
ctx = get_ssl_context(
certfile,
passphrase,
ca_certs,
crlfile,
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
def _parse_pool_options(
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
"""Parse connection pool options."""
credentials = _parse_credentials(username, password, database, options)
max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(
max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout,
socket_timeout,
wait_queue_timeout,
ssl_context,
tls_allow_invalid_hostnames,
_EventListeners(event_listeners),
appname,
driver,
compression_settings,
max_connecting=max_connecting,
server_api=server_api,
load_balanced=load_balanced,
credentials=credentials,
)
class ClientOptions:
"""Read only configuration options for a MongoClient.
Should not be instantiated directly by application developers. Access
a client's options via :attr:`pymongo.mongo_client.MongoClient.options`
instead.
"""
def __init__(
self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__direct_connection = options.get("directconnection")
self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
"serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
)
self.__pool_options = _parse_pool_options(username, password, database, options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get("replicaset")
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get("connect")
self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
self.__retry_reads = options.get("retryreads", common.RETRY_READS)
self.__server_selector = options.get("server_selector", any_server_selector)
self.__auto_encryption_opts = options.get("auto_encryption_opts")
self.__load_balanced = options.get("loadbalanced")
self.__timeout = options.get("timeoutms")
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
@property
def _options(self) -> Mapping[str, Any]:
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self) -> Optional[bool]:
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self) -> CodecOptions:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def direct_connection(self) -> Optional[bool]:
"""Whether to connect to the deployment in 'Single' topology."""
return self.__direct_connection
@property
def local_threshold_ms(self) -> int:
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self) -> int:
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self) -> _ServerSelector:
return self.__server_selector
@property
def heartbeat_frequency(self) -> int:
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self) -> PoolOptions:
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self) -> _ServerMode:
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self) -> WriteConcern:
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self) -> ReadConcern:
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def timeout(self) -> Optional[float]:
"""The configured timeoutMS converted to seconds, or None.
.. versionadded:: 4.2
"""
return self.__timeout
@property
def retry_writes(self) -> bool:
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self) -> bool:
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self.__load_balanced
@property
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()
@property
def server_monitoring_mode(self) -> str:
"""The configured serverMonitoringMode option.
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,226 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with `collations`_.
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
from __future__ import annotations
from typing import Any, Mapping, Optional, Union
from pymongo.synchronous import common
from pymongo.write_concern import validate_boolean
_IS_SYNC = True
class CollationStrength:
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate:
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = "non-ignorable"
"""Spaces and punctuation are treated as base characters."""
SHIFTED = "shifted"
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished regardless when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable:
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = "punct"
"""Both punctuation and spaces are ignored."""
SPACE = "space"
"""Spaces alone are ignored."""
class CollationCaseFirst:
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = "upper"
"""Sort uppercase characters first."""
LOWER = "lower"
"""Sort lowercase characters first."""
OFF = "off"
"""Default for locale or collation strength."""
class Collation:
"""Collation
:param locale: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
:param caseLevel: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
:param caseFirst: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
:param strength: Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
:param numericOrdering: If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
:param alternate: Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
:param maxVariable: When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
:param normalization: If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
:param backwards: If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
:param kwargs: Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded: 3.4
"""
__slots__ = ("__document",)
def __init__(
self,
locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst)
if strength is not None:
self.__document["strength"] = common.validate_integer("strength", strength)
if numericOrdering is not None:
self.__document["numericOrdering"] = validate_boolean(
"numericOrdering", numericOrdering
)
if alternate is not None:
self.__document["alternate"] = common.validate_string("alternate", alternate)
if maxVariable is not None:
self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable)
if normalization is not None:
self.__document["normalization"] = validate_boolean("normalization", normalization)
if backwards is not None:
self.__document["backwards"] = validate_boolean("backwards", backwards)
self.__document.update(kwargs)
@property
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self) -> str:
document = self.document
return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,415 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CommandCursor class to iterate over command results."""
from __future__ import annotations
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.synchronous.cursor import _ConnectionManager
from pymongo.synchronous.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
_RawBatchGetMore,
)
from pymongo.synchronous.response import PinnedResponse
from pymongo.synchronous.typings import _Address, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.pool import Connection
_IS_SYNC = True
class CommandCursor(Generic[_DocumentType]):
"""A cursor / iterator over command cursors."""
_getmore_class = _GetMore
def __init__(
self,
collection: Collection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
self._sock_mgr: Any = None
self._collection: Collection[_DocumentType] = collection
self._id = cursor_info["id"]
self._data = deque(cursor_info["firstBatch"])
self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
"postBatchResumeToken"
)
self._address = address
self._batch_size = batch_size
self._max_await_time_ms = max_await_time_ms
self._session = session
self._explicit_session = explicit_session
self._killed = self._id == 0
self._comment = comment
if _IS_SYNC and self._killed:
self._end_session(True) # type: ignore[unused-coroutine]
if "ns" in cursor_info: # noqa: SIM401
self._ns = cursor_info["ns"]
else:
self._ns = collection.full_name
self.batch_size(batch_size)
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self) -> None:
if _IS_SYNC:
self._die(False) # type: ignore[unused-coroutine]
def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
.. note:: batch_size can not override MongoDB's internal limits on the
amount of data it will return to the client in a single batch (i.e
if you set batch size to 1,000,000,000, MongoDB will currently only
return 4-16MB of results per batch).
Raises :exc:`TypeError` if `batch_size` is not an integer.
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self._batch_size = batch_size == 1 and 2 or batch_size
return self
def _has_next(self) -> bool:
"""Returns `True` if the cursor has documents remaining from the
previous batch.
"""
return len(self._data) > 0
@property
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore.
"""
return self._postbatchresumetoken
def _maybe_pin_connection(self, conn: Connection) -> None:
client = self._collection.database.client
if not client._should_pin_cursor(self._session):
return
if not self._sock_mgr:
conn.pin_cursor()
conn_mgr = _ConnectionManager(conn, False)
# Ensure the connection gets returned when the entire result is
# returned in the first batch.
if self._id == 0:
conn_mgr.close()
else:
self._sock_mgr = conn_mgr
def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration`. Best to use a for loop::
async for doc in collection.aggregate(pipeline):
print(doc)
.. note:: :attr:`alive` can be True while iterating a cursor from
a failed server. In this case :attr:`alive` will return False after
:meth:`next` fails to retrieve the next batch of results from the
server.
"""
return bool(len(self._data) or (not self._killed))
@property
def cursor_id(self) -> int:
"""Returns the id of the cursor."""
return self._id
@property
def address(self) -> Optional[_Address]:
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
"""
return self._address
@property
def session(self) -> Optional[ClientSession]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
.. versionadded:: 3.6
"""
if self._explicit_session:
return self._session
return None
def _die(self, synchronous: bool = False) -> None:
"""Closes this cursor."""
already_killed = self._killed
self._killed = True
if self._id and not already_killed:
cursor_id = self._id
assert self._address is not None
address = _CursorAddress(self._address, self._ns)
else:
# Skip killCursors.
cursor_id = 0
address = None
self._collection.database.client._cleanup_cursor(
synchronous,
cursor_id,
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
def _end_session(self, synchronous: bool) -> None:
if self._session and not self._explicit_session:
self._session._end_session(lock=synchronous)
self._session = None
def close(self) -> None:
"""Explicitly close / kill this cursor."""
self._die(True)
def _send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
client = self._collection.database.client
try:
response = client._run_operation(
operation, self._unpack_response, address=self._address
)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS:
# Don't send killCursors because the cursor is already closed.
self._killed = True
if exc.timeout:
self._die(False)
else:
# Return the session and pinned connection, if necessary.
self.close()
raise
except ConnectionFailure:
# Don't send killCursors because the cursor is already closed.
self._killed = True
# Return the session and pinned connection, if necessary.
self.close()
raise
except Exception:
self.close()
raise
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
self._postbatchresumetoken = cursor.get("postBatchResumeToken")
self._id = cursor["id"]
else:
documents = response.docs
assert isinstance(response.data, _OpReply)
self._id = response.data.cursor_id
if self._id == 0:
self.close()
self._data = deque(documents)
def _refresh(self) -> int:
"""Refreshes the cursor with more data from the server.
Returns the length of self._data after refresh. Will exit early if
self._data is already non-empty. Raises OperationFailure when the
cursor cannot be refreshed due to an error on the query.
"""
if len(self._data) or self._killed:
return len(self._data)
if self._id: # Get More
dbname, collname = self._ns.split(".", 1)
read_pref = self._collection._read_preference_for(self.session)
self._send_message(
self._getmore_class(
dbname,
collname,
self._batch_size,
self._id,
self._collection.codec_options,
read_pref,
self._session,
self._collection.database.client,
self._max_await_time_ms,
self._sock_mgr,
False,
self._comment,
)
)
else: # Cursor id is zero nothing else to return
self._die(True)
return len(self._data)
def __iter__(self) -> Iterator[_DocumentType]:
return self
def next(self) -> _DocumentType:
"""Advance the cursor."""
# Block until a document is returnable.
while self.alive:
doc = self._try_next(True)
if doc is not None:
return doc
raise StopIteration
def __next__(self) -> _DocumentType:
return self.next()
def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
"""Advance the cursor blocking for at most one getMore command."""
if not len(self._data) and not self._killed and get_more_allowed:
self._refresh()
if len(self._data):
return self._data.popleft()
else:
return None
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next document without waiting
indefinitely for data.
If no document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there is no additional data) then ``None`` is returned.
:return: The next document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 4.5
"""
return self._try_next(get_more_allowed=True)
def __enter__(self) -> CommandCursor[_DocumentType]:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def to_list(self) -> list[_DocumentType]:
return [x for x in self] # noqa: C416,RUF100
class RawBatchCommandCursor(CommandCursor[_DocumentType]):
_getmore_class = _RawBatchGetMore
def __init__(
self,
collection: Collection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[ClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
see :meth:`~pymongo.collection.Collection.aggregate_raw_batches`
instead.
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
assert not cursor_info.get("firstBatch")
super().__init__(
collection,
cursor_info,
address,
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)
def _unpack_response( # type: ignore[override]
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[Mapping[str, Any]]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
# Re-assemble the array of documents into a document stream
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response # type: ignore[return-value]
def __getitem__(self, index: int) -> NoReturn:
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")

View File

@ -40,20 +40,22 @@ from bson import SON
from bson.binary import UuidRepresentation
from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry
from bson.raw_bson import RawBSONDocument
from pymongo.compression_support import (
from pymongo.driver_info import DriverInfo
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.server_api import ServerApi
from pymongo.synchronous.compression_support import (
validate_compressors,
validate_zlib_compression_level,
)
from pymongo.driver_info import DriverInfo
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _validate_event_listeners
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
from pymongo.server_api import ServerApi
from pymongo.synchronous.monitoring import _validate_event_listeners
from pymongo.synchronous.read_preferences import _MONGOS_MODES, _ServerMode
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.synchronous.client_session import ClientSession
_IS_SYNC = True
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
@ -378,7 +380,7 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
def validate_auth_mechanism(option: str, value: Any) -> str:
"""Validate the authMechanism URI option."""
from pymongo.auth import MECHANISMS
from pymongo.synchronous.auth import MECHANISMS
if value not in MECHANISMS:
raise ValueError(f"{option} must be in {tuple(MECHANISMS)}")
@ -444,7 +446,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
elif key in ["ALLOWED_HOSTS"] and isinstance(value, list):
props[key] = value
elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]:
from pymongo.auth_oidc import OIDCCallback
from pymongo.synchronous.auth_oidc import OIDCCallback
if not isinstance(value, OIDCCallback):
raise ValueError("callback must be an OIDCCallback object")
@ -640,7 +642,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
"""Validate the driver keyword arg."""
if value is None:
return value
from pymongo.encryption_options import AutoEncryptionOpts
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
if not isinstance(value, AutoEncryptionOpts):
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
@ -902,7 +904,7 @@ class BaseObject:
) -> None:
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
self.__codec_options = codec_options
self._codec_options = codec_options
if not isinstance(read_preference, _ServerMode):
raise TypeError(
@ -910,24 +912,24 @@ class BaseObject:
"pymongo.read_preferences for valid "
"options."
)
self.__read_preference = read_preference
self._read_preference = read_preference
if not isinstance(write_concern, WriteConcern):
raise TypeError(
"write_concern must be an instance of pymongo.write_concern.WriteConcern"
)
self.__write_concern = write_concern
self._write_concern = write_concern
if not isinstance(read_concern, ReadConcern):
raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern")
self.__read_concern = read_concern
self._read_concern = read_concern
@property
def codec_options(self) -> CodecOptions:
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
of this instance.
"""
return self.__codec_options
return self._codec_options
@property
def write_concern(self) -> WriteConcern:
@ -937,7 +939,7 @@ class BaseObject:
.. versionchanged:: 3.0
The :attr:`write_concern` attribute is now read only.
"""
return self.__write_concern
return self._write_concern
def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern:
"""Read only access to the write concern of this instance or session."""
@ -953,14 +955,14 @@ class BaseObject:
.. versionchanged:: 3.0
The :attr:`read_preference` attribute is now read only.
"""
return self.__read_preference
return self._read_preference
def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode:
"""Read only access to the read preference of this instance or session."""
# Override this operation's read preference with the transaction's.
if session:
return session._txn_read_preference() or self.__read_preference
return self.__read_preference
return session._txn_read_preference() or self._read_preference
return self._read_preference
@property
def read_concern(self) -> ReadConcern:
@ -969,7 +971,7 @@ class BaseObject:
.. versionadded:: 3.2
"""
return self.__read_concern
return self._read_concern
class _CaseInsensitiveDictionary(MutableMapping[str, Any]):

View File

@ -16,8 +16,11 @@ from __future__ import annotations
import warnings
from typing import Any, Iterable, Optional, Union
from pymongo.hello import HelloCompat
from pymongo.helpers import _SENSITIVE_COMMANDS
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
from pymongo.synchronous.hello_compat import HelloCompat
_IS_SYNC = True
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
@ -146,6 +149,7 @@ class ZstdContext:
def compress(data: bytes) -> bytes:
# ZstdCompressor is not thread safe.
# TODO: Use a pool?
import zstandard
return zstandard.ZstdCompressor().compress(data)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,270 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for automatic client-side field level encryption."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
try:
import pymongocrypt # type:ignore[import] # noqa: F401
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.errors import ConfigurationError
from pymongo.synchronous.common import validate_is_mapping
from pymongo.synchronous.uri_parser import _parse_kms_tls_options
if TYPE_CHECKING:
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.typings import _DocumentTypeArg
_IS_SYNC = True
class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""
def __init__(
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: bool = False,
mongocryptd_uri: str = "mongodb://localhost:27020",
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = "mongocryptd",
mongocryptd_spawn_args: Optional[list[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None,
crypt_shared_lib_path: Optional[str] = None,
crypt_shared_lib_required: bool = False,
bypass_query_analysis: bool = False,
encrypted_fields_map: Optional[Mapping[str, Any]] = None,
) -> None:
"""Options to configure automatic client-side field level encryption.
Automatic client-side field level encryption requires MongoDB >=4.2
enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
supported for operations on a database or view and will result in
error.
Although automatic encryption requires MongoDB >=4.2 enterprise or a
MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
users. To configure automatic *decryption* without automatic
*encryption* set ``bypass_auto_encryption=True``. Explicit
encryption and explicit decryption is also supported for all users
with the :class:`~pymongo.encryption.ClientEncryption` class.
See :ref:`automatic-client-side-encryption` for an example.
:param kms_providers: Map of KMS provider options. The `kms_providers`
map values differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages. An optional "sessionToken" may be
included to support temporary AWS credentials.
- `azure`: Map with "tenantId", "clientId", and "clientSecret" as
strings. Additionally, "identityPlatformEndpoint" may also be
specified as a string (defaults to 'login.microsoftonline.com').
These are the Azure Active Directory credentials used to
generate Azure Key Vault messages.
- `gcp`: Map with "email" as a string and "privateKey"
as `bytes` or a base64 encoded string.
Additionally, "endpoint" may also be specified as a string
(defaults to 'oauth2.googleapis.com'). These are the
credentials used to generate Google Cloud KMS messages.
- `kmip`: Map with "endpoint" as a host with required port.
For example: ``{"endpoint": "example.com:443"}``.
- `local`: Map with "key" as `bytes` (96 bytes in length) or
a base64 encoded string which decodes
to 96 bytes. "key" is the master key used to encrypt/decrypt
data keys. This key should be generated and stored as securely
as possible.
KMS providers may be specified with an optional name suffix
separated by a colon, for example "kmip:name" or "aws:name".
Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
Named KMS providers enables more than one of each KMS provider type to be configured.
For example, to configure multiple local KMS providers::
kms_providers = {
"local": {"key": local_kek1}, # Unnamed KMS provider.
"local:myname": {"key": local_kek2}, # Named KMS provider with name "myname".
}
:param key_vault_namespace: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
:param key_vault_client: By default, the key vault collection
is assumed to reside in the same MongoDB cluster as the encrypted
MongoClient. Use this option to route data key queries to a
separate MongoDB cluster.
:param schema_map: Map of collection namespace ("db.coll") to
JSON Schema. By default, a collection's JSONSchema is periodically
polled with the listCollections command. But a JSONSchema may be
specified locally with the schemaMap option.
**Supplying a `schema_map` provides more security than relying on
JSON Schemas obtained from the server. It protects against a
malicious server advertising a false JSON Schema, which could trick
the client into sending unencrypted data that should be
encrypted.**
Schemas supplied in the schemaMap only apply to configuring
automatic encryption for client side encryption. Other validation
rules in the JSON schema will not be enforced by the driver and
will result in an error.
:param bypass_auto_encryption: If ``True``, automatic
encryption will be disabled but automatic decryption will still be
enabled. Defaults to ``False``.
:param mongocryptd_uri: The MongoDB URI used to connect
to the *local* mongocryptd process. Defaults to
``'mongodb://localhost:27020'``.
:param mongocryptd_bypass_spawn: If ``True``, the encrypted
MongoClient will not attempt to spawn the mongocryptd process.
Defaults to ``False``.
:param mongocryptd_spawn_path: Used for spawning the
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
mongocryptd from the system path.
:param mongocryptd_spawn_args: A list of string arguments to
use when spawning the mongocryptd process. Defaults to
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
the ``idleShutdownTimeoutSecs`` option then
``'--idleShutdownTimeoutSecs=60'`` will be added.
:param kms_tls_options: A map of KMS provider names to TLS
options to use when creating secure connections to KMS providers.
Accepts the same TLS options as
:class:`pymongo.mongo_client.MongoClient`. For example, to
override the system default CA file::
kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
Or to supply a client certificate::
kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
:param crypt_shared_lib_path: Override the path to load the crypt_shared library.
:param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
unable to load the crypt_shared library.
:param bypass_query_analysis: If ``True``, disable automatic analysis
of outgoing commands. Set `bypass_query_analysis` to use explicit
encryption on indexed fields without the MongoDB Enterprise Advanced
licensed crypt_shared library.
:param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
that described the encrypted fields for Queryable Encryption. For example::
{
"db.encryptedCollection": {
"escCollection": "enxcol_.encryptedCollection.esc",
"ecocCollection": "enxcol_.encryptedCollection.ecoc",
"fields": [
{
"path": "firstName",
"keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
"bsonType": "string",
"queries": {"queryType": "equality"}
},
{
"path": "ssn",
"keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
"bsonType": "string"
}
]
}
}
.. versionchanged:: 4.2
Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`,
and `bypass_query_analysis` parameters.
.. versionchanged:: 4.0
Added the `kms_tls_options` parameter and the "kmip" KMS provider.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
self._bypass_query_analysis = bypass_query_analysis
self._crypt_shared_lib_path = crypt_shared_lib_path
self._crypt_shared_lib_required = crypt_shared_lib_required
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._schema_map = schema_map
self._bypass_auto_encryption = bypass_auto_encryption
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
if mongocryptd_spawn_args is None:
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError("mongocryptd_spawn_args must be a list")
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
self._bypass_query_analysis = bypass_query_analysis
class RangeOpts:
"""Options to configure encrypted queries using the rangePreview algorithm."""
def __init__(
self,
sparsity: int,
min: Optional[Any] = None,
max: Optional[Any] = None,
precision: Optional[int] = None,
) -> None:
"""Options to configure encrypted queries using the rangePreview algorithm.
.. note:: This feature is experimental only, and not intended for public use.
:param sparsity: An integer.
:param min: A BSON scalar value corresponding to the type being queried.
:param max: A BSON scalar value corresponding to the type being queried.
:param precision: An integer, may only be set for double or decimal128 types.
.. versionadded:: 4.4
"""
self.min = min
self.max = max
self.sparsity = sparsity
self.precision = precision
@property
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),
("precision", self.precision),
("min", self.min),
("max", self.max),
]:
if v is not None:
doc[k] = v
return doc

View File

@ -0,0 +1,225 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example event logger classes.
.. versionadded:: 3.11
These loggers can be registered using :func:`register` or
:class:`~pymongo.mongo_client.MongoClient`.
``monitoring.register(CommandLogger())``
or
``MongoClient(event_listeners=[CommandLogger()])``
"""
from __future__ import annotations
import logging
from pymongo.synchronous import monitoring
_IS_SYNC = True
class CommandLogger(monitoring.CommandListener):
"""A simple listener that logs command events.
Listens for :class:`~pymongo.monitoring.CommandStartedEvent`,
:class:`~pymongo.monitoring.CommandSucceededEvent` and
:class:`~pymongo.monitoring.CommandFailedEvent` events and
logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.CommandStartedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} started on server "
f"{event.connection_id}"
)
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"succeeded in {event.duration_micros} "
"microseconds"
)
def failed(self, event: monitoring.CommandFailedEvent) -> None:
logging.info(
f"Command {event.command_name} with request id "
f"{event.request_id} on server {event.connection_id} "
f"failed in {event.duration_micros} "
"microseconds"
)
class ServerLogger(monitoring.ServerListener):
"""A simple listener that logs server discovery events.
Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`,
:class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.ServerClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
logging.info(f"Server {event.server_address} added to topology {event.topology_id}")
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
previous_server_type = event.previous_description.server_type
new_server_type = event.new_description.server_type
if new_server_type != previous_server_type:
# server_type_name was added in PyMongo 3.4
logging.info(
f"Server {event.server_address} changed type from "
f"{event.previous_description.server_type_name} to "
f"{event.new_description.server_type_name}"
)
def closed(self, event: monitoring.ServerClosedEvent) -> None:
logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}")
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
"""A simple listener that logs server heartbeat events.
Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
:class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
logging.info(f"Heartbeat sent to server {event.connection_id}")
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
# The reply.document attribute was added in PyMongo 3.4.
logging.info(
f"Heartbeat to server {event.connection_id} "
"succeeded with reply "
f"{event.reply.document}"
)
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
logging.warning(
f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
)
class TopologyLogger(monitoring.TopologyListener):
"""A simple listener that logs server topology events.
Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`,
:class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
and :class:`~pymongo.monitoring.TopologyClosedEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} opened")
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
logging.info(f"Topology description updated for topology id {event.topology_id}")
previous_topology_type = event.previous_description.topology_type
new_topology_type = event.new_description.topology_type
if new_topology_type != previous_topology_type:
# topology_type_name was added in PyMongo 3.4
logging.info(
f"Topology {event.topology_id} changed type from "
f"{event.previous_description.topology_type_name} to "
f"{event.new_description.topology_type_name}"
)
# The has_writable_server and has_readable_server methods
# were added in PyMongo 3.4.
if not event.new_description.has_writable_server():
logging.warning("No writable servers available.")
if not event.new_description.has_readable_server():
logging.warning("No readable servers available.")
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
logging.info(f"Topology with id {event.topology_id} closed")
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
"""A simple listener that logs server connection pool events.
Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
:class:`~pymongo.monitoring.PoolClearedEvent`,
:class:`~pymongo.monitoring.PoolClosedEvent`,
:~pymongo.monitoring.class:`ConnectionCreatedEvent`,
:class:`~pymongo.monitoring.ConnectionReadyEvent`,
:class:`~pymongo.monitoring.ConnectionClosedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
:class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
events and logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
logging.info(f"[pool {event.address}] pool created")
def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
logging.info(f"[pool {event.address}] pool ready")
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
logging.info(f"[pool {event.address}] pool cleared")
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
logging.info(f"[pool {event.address}] pool closed")
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
)
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] "
f'connection closed, reason: "{event.reason}"'
)
def connection_check_out_started(
self, event: monitoring.ConnectionCheckOutStartedEvent
) -> None:
logging.info(f"[pool {event.address}] connection check out started")
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
)
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
logging.info(
f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
)

View File

@ -0,0 +1,219 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for the 'hello' and legacy hello commands."""
from __future__ import annotations
import copy
import datetime
import itertools
from typing import Any, Generic, Mapping, Optional
from bson.objectid import ObjectId
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import common
from pymongo.synchronous.hello_compat import HelloCompat
from pymongo.synchronous.typings import ClusterTime, _DocumentType
_IS_SYNC = True
def _get_server_type(doc: Mapping[str, Any]) -> int:
"""Determine the server type from a hello response."""
if not doc.get("ok"):
return SERVER_TYPE.Unknown
if doc.get("serviceId"):
return SERVER_TYPE.LoadBalancer
elif doc.get("isreplicaset"):
return SERVER_TYPE.RSGhost
elif doc.get("setName"):
if doc.get("hidden"):
return SERVER_TYPE.RSOther
elif doc.get(HelloCompat.PRIMARY):
return SERVER_TYPE.RSPrimary
elif doc.get(HelloCompat.LEGACY_PRIMARY):
return SERVER_TYPE.RSPrimary
elif doc.get("secondary"):
return SERVER_TYPE.RSSecondary
elif doc.get("arbiterOnly"):
return SERVER_TYPE.RSArbiter
else:
return SERVER_TYPE.RSOther
elif doc.get("msg") == "isdbgrid":
return SERVER_TYPE.Mongos
else:
return SERVER_TYPE.Standalone
class Hello(Generic[_DocumentType]):
"""Parse a hello response from the server.
.. versionadded:: 3.12
"""
__slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable")
def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None:
self._server_type = _get_server_type(doc)
self._doc: _DocumentType = doc
self._is_writable = self._server_type in (
SERVER_TYPE.RSPrimary,
SERVER_TYPE.Standalone,
SERVER_TYPE.Mongos,
SERVER_TYPE.LoadBalancer,
)
self._is_readable = self.server_type == SERVER_TYPE.RSSecondary or self._is_writable
self._awaitable = awaitable
@property
def document(self) -> _DocumentType:
"""The complete hello command response document.
.. versionadded:: 3.4
"""
return copy.copy(self._doc)
@property
def server_type(self) -> int:
return self._server_type
@property
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return set(
map(
common.clean_node,
itertools.chain(
self._doc.get("hosts", []),
self._doc.get("passives", []),
self._doc.get("arbiters", []),
),
)
)
@property
def tags(self) -> Mapping[str, Any]:
"""Replica set member tags or empty dict."""
return self._doc.get("tags", {})
@property
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
if self._doc.get("primary"):
return common.partition_node(self._doc["primary"])
else:
return None
@property
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._doc.get("setName")
@property
def max_bson_size(self) -> int:
return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE)
@property
def max_message_size(self) -> int:
return self._doc.get("maxMessageSizeBytes", 2 * self.max_bson_size)
@property
def max_write_batch_size(self) -> int:
return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE)
@property
def min_wire_version(self) -> int:
return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION)
@property
def max_wire_version(self) -> int:
return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION)
@property
def set_version(self) -> Optional[int]:
return self._doc.get("setVersion")
@property
def election_id(self) -> Optional[ObjectId]:
return self._doc.get("electionId")
@property
def cluster_time(self) -> Optional[ClusterTime]:
return self._doc.get("$clusterTime")
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._doc.get("logicalSessionTimeoutMinutes")
@property
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self) -> bool:
return self._is_readable
@property
def me(self) -> Optional[tuple[str, int]]:
me = self._doc.get("me")
if me:
return common.clean_node(me)
return None
@property
def last_write_date(self) -> Optional[datetime.datetime]:
return self._doc.get("lastWrite", {}).get("lastWriteDate")
@property
def compressors(self) -> Optional[list[str]]:
return self._doc.get("compression")
@property
def sasl_supported_mechs(self) -> list[str]:
"""Supported authentication mechanisms for the current user.
For example::
>>> hello.sasl_supported_mechs
["SCRAM-SHA-1", "SCRAM-SHA-256"]
"""
return self._doc.get("saslSupportedMechs", [])
@property
def speculative_authenticate(self) -> Optional[Mapping[str, Any]]:
"""The speculativeAuthenticate field."""
return self._doc.get("speculativeAuthenticate")
@property
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._doc.get("topologyVersion")
@property
def awaitable(self) -> bool:
return self._awaitable
@property
def service_id(self) -> Optional[ObjectId]:
return self._doc.get("serviceId")
@property
def hello_ok(self) -> bool:
return self._doc.get("helloOk", False)
@property
def connection_id(self) -> Optional[int]:
return self._doc.get("connectionId")

Some files were not shown because too many files have changed in this diff Show More