PYTHON-4264 Async PyMongo Beta (#1629)
This commit is contained in:
parent
e9c86f4c00
commit
d6bf0e1e78
@ -17,6 +17,17 @@ repos:
|
||||
exclude: .patch
|
||||
exclude_types: [json]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: synchro
|
||||
name: synchro
|
||||
entry: bash ./tools/synchro.sh
|
||||
language: python
|
||||
require_serial: true
|
||||
additional_dependencies:
|
||||
- ruff==0.1.3
|
||||
- unasync
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.1.3
|
||||
@ -74,7 +85,7 @@ repos:
|
||||
stages: [manual]
|
||||
|
||||
- repo: https://github.com/ariebovenberg/slotscheck
|
||||
rev: v0.17.0
|
||||
rev: v0.19.0
|
||||
hooks:
|
||||
- id: slotscheck
|
||||
files: \.py$
|
||||
|
||||
@ -22,6 +22,7 @@ include doc/make.bat
|
||||
include doc/static/periodic-executor-refs.dot
|
||||
recursive-include requirements *.txt
|
||||
recursive-include tools *.py
|
||||
recursive-include tools *.sh
|
||||
include tools/README.rst
|
||||
include green_framework_test.py
|
||||
recursive-include test *.pem
|
||||
|
||||
@ -3,13 +3,13 @@ Changelog
|
||||
|
||||
Changes in Version 4.8.0
|
||||
-------------------------
|
||||
|
||||
The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time.
|
||||
|
||||
The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8).
|
||||
|
||||
.. warning:: PyMongo 4.8 drops support for Python 3.7 and PyPy 3.8: Python 3.8+ or PyPy 3.9+ is now required.
|
||||
|
||||
PyMongo 4.8 brings a number of improvements including:
|
||||
- The handshake metadata for "os.name" on Windows has been simplified to "Windows" to improve import time.
|
||||
- The repr of ``bson.binary.Binary`` is now redacted when the subtype is SENSITIVE_SUBTYPE(8).
|
||||
- A new asynchronous API with full asyncio support.
|
||||
|
||||
Changes in Version 4.7.3
|
||||
-------------------------
|
||||
|
||||
|
||||
@ -21,980 +21,34 @@ The :mod:`gridfs` package is an implementation of GridFS on top of
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import abc
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs.asynchronous.grid_file import (
|
||||
AsyncGridFS,
|
||||
AsyncGridFSBucket,
|
||||
AsyncGridIn,
|
||||
AsyncGridOut,
|
||||
AsyncGridOutCursor,
|
||||
)
|
||||
from gridfs.errors import NoFile
|
||||
from gridfs.grid_file import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
from gridfs.grid_file_shared import DEFAULT_CHUNK_SIZE
|
||||
from gridfs.synchronous.grid_file import (
|
||||
GridFS,
|
||||
GridFSBucket,
|
||||
GridIn,
|
||||
GridOut,
|
||||
GridOutCursor,
|
||||
_clear_entity_type_registry,
|
||||
_disallow_transactions,
|
||||
)
|
||||
from pymongo import ASCENDING, DESCENDING, _csot
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.common import validate_string
|
||||
from pymongo.database import Database
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
__all__ = [
|
||||
"AsyncGridFS",
|
||||
"GridFS",
|
||||
"AsyncGridFSBucket",
|
||||
"GridFSBucket",
|
||||
"NoFile",
|
||||
"DEFAULT_CHUNK_SIZE",
|
||||
"AsyncGridIn",
|
||||
"GridIn",
|
||||
"AsyncGridOut",
|
||||
"GridOut",
|
||||
"AsyncGridOutCursor",
|
||||
"GridOutCursor",
|
||||
]
|
||||
|
||||
|
||||
class GridFS:
|
||||
"""An instance of GridFS on top of a single Database."""
|
||||
|
||||
def __init__(self, database: Database, collection: str = "fs"):
|
||||
"""Create a new instance of :class:`GridFS`.
|
||||
|
||||
Raises :class:`TypeError` if `database` is not an instance of
|
||||
:class:`~pymongo.database.Database`.
|
||||
|
||||
:param database: database to use
|
||||
:param collection: root collection to use
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Removed the `disable_md5` parameter. See
|
||||
:ref:`removed-gridfs-checksum` for details.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Running a GridFS operation in a transaction now always raises an
|
||||
error. GridFS does not support multi-document transactions.
|
||||
|
||||
.. versionchanged:: 3.7
|
||||
Added the `disable_md5` parameter.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
Indexes are only ensured on the first write to the DB.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
`database` must use an acknowledged
|
||||
:attr:`~pymongo.database.Database.write_concern`
|
||||
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(database, Database):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
|
||||
database = _clear_entity_type_registry(database)
|
||||
|
||||
if not database.write_concern.acknowledged:
|
||||
raise ConfigurationError("database must use acknowledged write_concern")
|
||||
|
||||
self.__collection = database[collection]
|
||||
self.__files = self.__collection.files
|
||||
self.__chunks = self.__collection.chunks
|
||||
|
||||
def new_file(self, **kwargs: Any) -> GridIn:
|
||||
"""Create a new file in GridFS.
|
||||
|
||||
Returns a new :class:`~gridfs.grid_file.GridIn` instance to
|
||||
which data can be written. Any keyword arguments will be
|
||||
passed through to :meth:`~gridfs.grid_file.GridIn`.
|
||||
|
||||
If the ``"_id"`` of the file is manually specified, it must
|
||||
not already exist in GridFS. Otherwise
|
||||
:class:`~gridfs.errors.FileExists` is raised.
|
||||
|
||||
:param kwargs: keyword arguments for file creation
|
||||
"""
|
||||
return GridIn(self.__collection, **kwargs)
|
||||
|
||||
def put(self, data: Any, **kwargs: Any) -> Any:
|
||||
"""Put data in GridFS as a new file.
|
||||
|
||||
Equivalent to doing::
|
||||
|
||||
with fs.new_file(**kwargs) as f:
|
||||
f.write(data)
|
||||
|
||||
`data` can be either an instance of :class:`bytes` or a file-like
|
||||
object providing a :meth:`read` method. If an `encoding` keyword
|
||||
argument is passed, `data` can also be a :class:`str` instance, which
|
||||
will be encoded as `encoding` before being written. Any keyword
|
||||
arguments will be passed through to the created file - see
|
||||
:meth:`~gridfs.grid_file.GridIn` for possible arguments. Returns the
|
||||
``"_id"`` of the created file.
|
||||
|
||||
If the ``"_id"`` of the file is manually specified, it must
|
||||
not already exist in GridFS. Otherwise
|
||||
:class:`~gridfs.errors.FileExists` is raised.
|
||||
|
||||
:param data: data to be written as a file.
|
||||
:param kwargs: keyword arguments for file creation
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
w=0 writes to GridFS are now prohibited.
|
||||
"""
|
||||
with GridIn(self.__collection, **kwargs) as grid_file:
|
||||
grid_file.write(data)
|
||||
return grid_file._id
|
||||
|
||||
def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut:
|
||||
"""Get a file from GridFS by ``"_id"``.
|
||||
|
||||
Returns an instance of :class:`~gridfs.grid_file.GridOut`,
|
||||
which provides a file-like interface for reading.
|
||||
|
||||
:param file_id: ``"_id"`` of the file to get
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
gout = GridOut(self.__collection, file_id, session=session)
|
||||
|
||||
# Raise NoFile now, instead of on first attribute access.
|
||||
gout._ensure_file()
|
||||
return gout
|
||||
|
||||
def get_version(
|
||||
self,
|
||||
filename: Optional[str] = None,
|
||||
version: Optional[int] = -1,
|
||||
session: Optional[ClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> GridOut:
|
||||
"""Get a file from GridFS by ``"filename"`` or metadata fields.
|
||||
|
||||
Returns a version of the file in GridFS whose filename matches
|
||||
`filename` and whose metadata fields match the supplied keyword
|
||||
arguments, as an instance of :class:`~gridfs.grid_file.GridOut`.
|
||||
|
||||
Version numbering is a convenience atop the GridFS API provided
|
||||
by MongoDB. If more than one file matches the query (either by
|
||||
`filename` alone, by metadata fields, or by a combination of
|
||||
both), then version ``-1`` will be the most recently uploaded
|
||||
matching file, ``-2`` the second most recently
|
||||
uploaded, etc. Version ``0`` will be the first version
|
||||
uploaded, ``1`` the second version, etc. So if three versions
|
||||
have been uploaded, then version ``0`` is the same as version
|
||||
``-3``, version ``1`` is the same as version ``-2``, and
|
||||
version ``2`` is the same as version ``-1``.
|
||||
|
||||
Raises :class:`~gridfs.errors.NoFile` if no such version of
|
||||
that file exists.
|
||||
|
||||
:param filename: ``"filename"`` of the file to get, or `None`
|
||||
:param version: version of the file to get (defaults
|
||||
to -1, the most recent version uploaded)
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:param kwargs: find files by custom metadata.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``get_version`` no longer ensures indexes.
|
||||
"""
|
||||
query = kwargs
|
||||
if filename is not None:
|
||||
query["filename"] = filename
|
||||
|
||||
_disallow_transactions(session)
|
||||
cursor = self.__files.find(query, session=session)
|
||||
if version is None:
|
||||
version = -1
|
||||
if version < 0:
|
||||
skip = abs(version) - 1
|
||||
cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING)
|
||||
else:
|
||||
cursor.limit(-1).skip(version).sort("uploadDate", ASCENDING)
|
||||
try:
|
||||
doc = next(cursor)
|
||||
return GridOut(self.__collection, file_document=doc, session=session)
|
||||
except StopIteration:
|
||||
raise NoFile("no version %d for filename %r" % (version, filename)) from None
|
||||
|
||||
def get_last_version(
|
||||
self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
|
||||
) -> GridOut:
|
||||
"""Get the most recent version of a file in GridFS by ``"filename"``
|
||||
or metadata fields.
|
||||
|
||||
Equivalent to calling :meth:`get_version` with the default
|
||||
`version` (``-1``).
|
||||
|
||||
:param filename: ``"filename"`` of the file to get, or `None`
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:param kwargs: find files by custom metadata.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
return self.get_version(filename=filename, session=session, **kwargs)
|
||||
|
||||
# TODO add optional safe mode for chunk removal?
|
||||
def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
|
||||
"""Delete a file from GridFS by ``"_id"``.
|
||||
|
||||
Deletes all data belonging to the file with ``"_id"``:
|
||||
`file_id`.
|
||||
|
||||
.. warning:: Any processes/threads reading from the file while
|
||||
this method is executing will likely see an invalid/corrupt
|
||||
file. Care should be taken to avoid concurrent reads to a file
|
||||
while it is being deleted.
|
||||
|
||||
.. note:: Deletes of non-existent files are considered successful
|
||||
since the end result is the same: no file with that _id remains.
|
||||
|
||||
:param file_id: ``"_id"`` of the file to delete
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``delete`` no longer ensures indexes.
|
||||
"""
|
||||
_disallow_transactions(session)
|
||||
self.__files.delete_one({"_id": file_id}, session=session)
|
||||
self.__chunks.delete_many({"files_id": file_id}, session=session)
|
||||
|
||||
def list(self, session: Optional[ClientSession] = None) -> list[str]:
|
||||
"""List the names of all files stored in this instance of
|
||||
:class:`GridFS`.
|
||||
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``list`` no longer ensures indexes.
|
||||
"""
|
||||
_disallow_transactions(session)
|
||||
# With an index, distinct includes documents with no filename
|
||||
# as None.
|
||||
return [
|
||||
name for name in self.__files.distinct("filename", session=session) if name is not None
|
||||
]
|
||||
|
||||
def find_one(
|
||||
self,
|
||||
filter: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[GridOut]:
|
||||
"""Get a single file from gridfs.
|
||||
|
||||
All arguments to :meth:`find` are also valid arguments for
|
||||
:meth:`find_one`, although any `limit` argument will be
|
||||
ignored. Returns a single :class:`~gridfs.grid_file.GridOut`,
|
||||
or ``None`` if no matching file is found. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
file = fs.find_one({"filename": "lisa.txt"})
|
||||
|
||||
:param filter: a dictionary specifying
|
||||
the query to be performed OR any other type to be used as
|
||||
the value for a query for ``"_id"`` in the file collection.
|
||||
:param args: any additional positional arguments are
|
||||
the same as the arguments to :meth:`find`.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:param kwargs: any additional keyword arguments
|
||||
are the same as the arguments to :meth:`find`.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
if filter is not None and not isinstance(filter, abc.Mapping):
|
||||
filter = {"_id": filter}
|
||||
|
||||
_disallow_transactions(session)
|
||||
for f in self.find(filter, *args, session=session, **kwargs):
|
||||
return f
|
||||
|
||||
return None
|
||||
|
||||
def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
|
||||
"""Query GridFS for files.
|
||||
|
||||
Returns a cursor that iterates across files matching
|
||||
arbitrary queries on the files collection. Can be combined
|
||||
with other modifiers for additional control. For example::
|
||||
|
||||
for grid_out in fs.find({"filename": "lisa.txt"},
|
||||
no_cursor_timeout=True):
|
||||
data = grid_out.read()
|
||||
|
||||
would iterate through all versions of "lisa.txt" stored in GridFS.
|
||||
Note that setting no_cursor_timeout to True may be important to
|
||||
prevent the cursor from timing out during long multi-file processing
|
||||
work.
|
||||
|
||||
As another example, the call::
|
||||
|
||||
most_recent_three = fs.find().sort("uploadDate", -1).limit(3)
|
||||
|
||||
would return a cursor to the three most recently uploaded files
|
||||
in GridFS.
|
||||
|
||||
Follows a similar interface to
|
||||
:meth:`~pymongo.collection.Collection.find`
|
||||
in :class:`~pymongo.collection.Collection`.
|
||||
|
||||
If a :class:`~pymongo.client_session.ClientSession` is passed to
|
||||
:meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
|
||||
are associated with that session.
|
||||
|
||||
:param filter: A query document that selects which files
|
||||
to include in the result set. Can be an empty document to include
|
||||
all files.
|
||||
:param skip: the number of files to omit (from
|
||||
the start of the result set) when returning the results
|
||||
:param limit: the maximum number of results to
|
||||
return
|
||||
:param no_cursor_timeout: if False (the default), any
|
||||
returned cursor is closed by the server after 10 minutes of
|
||||
inactivity. If set to True, the returned cursor will never
|
||||
time out on the server. Care should be taken to ensure that
|
||||
cursors with no_cursor_timeout turned on are properly closed.
|
||||
:param sort: a list of (key, direction) pairs
|
||||
specifying the sort order for this query. See
|
||||
:meth:`~pymongo.cursor.Cursor.sort` for details.
|
||||
|
||||
Raises :class:`TypeError` if any of the arguments are of
|
||||
improper type. Returns an instance of
|
||||
:class:`~gridfs.grid_file.GridOutCursor`
|
||||
corresponding to this query.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
Removed the read_preference, tag_sets, and
|
||||
secondary_acceptable_latency_ms options.
|
||||
.. versionadded:: 2.7
|
||||
.. seealso:: The MongoDB documentation on `find <https://dochub.mongodb.org/core/find>`_.
|
||||
"""
|
||||
return GridOutCursor(self.__collection, *args, **kwargs)
|
||||
|
||||
def exists(
|
||||
self,
|
||||
document_or_id: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Check if a file exists in this instance of :class:`GridFS`.
|
||||
|
||||
The file to check for can be specified by the value of its
|
||||
``_id`` key, or by passing in a query document. A query
|
||||
document can be passed in as a dictionary, or by using keyword
|
||||
arguments. Thus, the following three calls are equivalent:
|
||||
|
||||
>>> fs.exists(file_id)
|
||||
>>> fs.exists({"_id": file_id})
|
||||
>>> fs.exists(_id=file_id)
|
||||
|
||||
As are the following two calls:
|
||||
|
||||
>>> fs.exists({"filename": "mike.txt"})
|
||||
>>> fs.exists(filename="mike.txt")
|
||||
|
||||
And the following two:
|
||||
|
||||
>>> fs.exists({"foo": {"$gt": 12}})
|
||||
>>> fs.exists(foo={"$gt": 12})
|
||||
|
||||
Returns ``True`` if a matching file exists, ``False``
|
||||
otherwise. Calls to :meth:`exists` will not automatically
|
||||
create appropriate indexes; application developers should be
|
||||
sure to create indexes if needed and as appropriate.
|
||||
|
||||
:param document_or_id: query document, or _id of the
|
||||
document to check for
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
:param kwargs: keyword arguments are used as a
|
||||
query document, if they're present.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
_disallow_transactions(session)
|
||||
if kwargs:
|
||||
f = self.__files.find_one(kwargs, ["_id"], session=session)
|
||||
else:
|
||||
f = self.__files.find_one(document_or_id, ["_id"], session=session)
|
||||
|
||||
return f is not None
|
||||
|
||||
|
||||
class GridFSBucket:
|
||||
"""An instance of GridFS on top of a single Database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Database,
|
||||
bucket_name: str = "fs",
|
||||
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
) -> None:
|
||||
"""Create a new instance of :class:`GridFSBucket`.
|
||||
|
||||
Raises :exc:`TypeError` if `db` is not an instance of
|
||||
:class:`~pymongo.database.Database`.
|
||||
|
||||
Raises :exc:`~pymongo.errors.ConfigurationError` if `write_concern`
|
||||
is not acknowledged.
|
||||
|
||||
:param db: database to use.
|
||||
:param bucket_name: The name of the bucket. Defaults to 'fs'.
|
||||
:param chunk_size_bytes: The chunk size in bytes. Defaults
|
||||
to 255KB.
|
||||
:param write_concern: The
|
||||
:class:`~pymongo.write_concern.WriteConcern` to use. If ``None``
|
||||
(the default) db.write_concern is used.
|
||||
:param read_preference: The read preference to use. If
|
||||
``None`` (the default) db.read_preference is used.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Removed the `disable_md5` parameter. See
|
||||
:ref:`removed-gridfs-checksum` for details.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Running a GridFSBucket operation in a transaction now always raises
|
||||
an error. GridFSBucket does not support multi-document transactions.
|
||||
|
||||
.. versionchanged:: 3.7
|
||||
Added the `disable_md5` parameter.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(db, Database):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
|
||||
db = _clear_entity_type_registry(db)
|
||||
|
||||
wtc = write_concern if write_concern is not None else db.write_concern
|
||||
if not wtc.acknowledged:
|
||||
raise ConfigurationError("write concern must be acknowledged")
|
||||
|
||||
self._bucket_name = bucket_name
|
||||
self._collection = db[bucket_name]
|
||||
self._chunks: Collection = self._collection.chunks.with_options(
|
||||
write_concern=write_concern, read_preference=read_preference
|
||||
)
|
||||
|
||||
self._files: Collection = self._collection.files.with_options(
|
||||
write_concern=write_concern, read_preference=read_preference
|
||||
)
|
||||
|
||||
self._chunk_size_bytes = chunk_size_bytes
|
||||
self._timeout = db.client.options.timeout
|
||||
|
||||
def open_upload_stream(
|
||||
self,
|
||||
filename: str,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
) -> GridIn:
|
||||
"""Opens a Stream that the application can write the contents of the
|
||||
file to.
|
||||
|
||||
The user must specify the filename, and can choose to add any
|
||||
additional information in the metadata field of the file document or
|
||||
modify the chunk size.
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
with fs.open_upload_stream(
|
||||
"test_file", chunk_size_bytes=4,
|
||||
metadata={"contentType": "text/plain"}) as grid_in:
|
||||
grid_in.write("data I want to store!")
|
||||
# uploaded on close
|
||||
|
||||
Returns an instance of :class:`~gridfs.grid_file.GridIn`.
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no such version of
|
||||
that file exists.
|
||||
Raises :exc:`~ValueError` if `filename` is not a string.
|
||||
|
||||
:param filename: The name of the file to upload.
|
||||
:param chunk_size_bytes: (optional) The number of bytes per chunk of this
|
||||
file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`.
|
||||
:param metadata: User data for the 'metadata' field of the
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
validate_string("filename", filename)
|
||||
|
||||
opts = {
|
||||
"filename": filename,
|
||||
"chunk_size": (
|
||||
chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes
|
||||
),
|
||||
}
|
||||
if metadata is not None:
|
||||
opts["metadata"] = metadata
|
||||
|
||||
return GridIn(self._collection, session=session, **opts)
|
||||
|
||||
def open_upload_stream_with_id(
|
||||
self,
|
||||
file_id: Any,
|
||||
filename: str,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
) -> GridIn:
|
||||
"""Opens a Stream that the application can write the contents of the
|
||||
file to.
|
||||
|
||||
The user must specify the file id and filename, and can choose to add
|
||||
any additional information in the metadata field of the file document
|
||||
or modify the chunk size.
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
with fs.open_upload_stream_with_id(
|
||||
ObjectId(),
|
||||
"test_file",
|
||||
chunk_size_bytes=4,
|
||||
metadata={"contentType": "text/plain"}) as grid_in:
|
||||
grid_in.write("data I want to store!")
|
||||
# uploaded on close
|
||||
|
||||
Returns an instance of :class:`~gridfs.grid_file.GridIn`.
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no such version of
|
||||
that file exists.
|
||||
Raises :exc:`~ValueError` if `filename` is not a string.
|
||||
|
||||
:param file_id: The id to use for this file. The id must not have
|
||||
already been used for another file.
|
||||
:param filename: The name of the file to upload.
|
||||
:param chunk_size_bytes: (optional) The number of bytes per chunk of this
|
||||
file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`.
|
||||
:param metadata: User data for the 'metadata' field of the
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
validate_string("filename", filename)
|
||||
|
||||
opts = {
|
||||
"_id": file_id,
|
||||
"filename": filename,
|
||||
"chunk_size": (
|
||||
chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes
|
||||
),
|
||||
}
|
||||
if metadata is not None:
|
||||
opts["metadata"] = metadata
|
||||
|
||||
return GridIn(self._collection, session=session, **opts)
|
||||
|
||||
@_csot.apply
|
||||
def upload_from_stream(
|
||||
self,
|
||||
filename: str,
|
||||
source: Any,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
) -> ObjectId:
|
||||
"""Uploads a user file to a GridFS bucket.
|
||||
|
||||
Reads the contents of the user file from `source` and uploads
|
||||
it to the file `filename`. Source can be a string or file-like object.
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
file_id = fs.upload_from_stream(
|
||||
"test_file",
|
||||
"data I want to store!",
|
||||
chunk_size_bytes=4,
|
||||
metadata={"contentType": "text/plain"})
|
||||
|
||||
Returns the _id of the uploaded file.
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no such version of
|
||||
that file exists.
|
||||
Raises :exc:`~ValueError` if `filename` is not a string.
|
||||
|
||||
:param filename: The name of the file to upload.
|
||||
:param source: The source stream of the content to be uploaded. Must be
|
||||
a file-like object that implements :meth:`read` or a string.
|
||||
:param chunk_size_bytes: (optional) The number of bytes per chunk of this
|
||||
file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`.
|
||||
:param metadata: User data for the 'metadata' field of the
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
with self.open_upload_stream(filename, chunk_size_bytes, metadata, session=session) as gin:
|
||||
gin.write(source)
|
||||
|
||||
return cast(ObjectId, gin._id)
|
||||
|
||||
@_csot.apply
|
||||
def upload_from_stream_with_id(
|
||||
self,
|
||||
file_id: Any,
|
||||
filename: str,
|
||||
source: Any,
|
||||
chunk_size_bytes: Optional[int] = None,
|
||||
metadata: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
) -> None:
|
||||
"""Uploads a user file to a GridFS bucket with a custom file id.
|
||||
|
||||
Reads the contents of the user file from `source` and uploads
|
||||
it to the file `filename`. Source can be a string or file-like object.
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
fs.upload_from_stream_with_id(
|
||||
ObjectId(),
|
||||
"test_file",
|
||||
"data I want to store!",
|
||||
chunk_size_bytes=4,
|
||||
metadata={"contentType": "text/plain"})
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no such version of
|
||||
that file exists.
|
||||
Raises :exc:`~ValueError` if `filename` is not a string.
|
||||
|
||||
:param file_id: The id to use for this file. The id must not have
|
||||
already been used for another file.
|
||||
:param filename: The name of the file to upload.
|
||||
:param source: The source stream of the content to be uploaded. Must be
|
||||
a file-like object that implements :meth:`read` or a string.
|
||||
:param chunk_size_bytes: (optional) The number of bytes per chunk of this
|
||||
file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`.
|
||||
:param metadata: User data for the 'metadata' field of the
|
||||
files collection document. If not provided the metadata field will
|
||||
be omitted from the files collection document.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
with self.open_upload_stream_with_id(
|
||||
file_id, filename, chunk_size_bytes, metadata, session=session
|
||||
) as gin:
|
||||
gin.write(source)
|
||||
|
||||
def open_download_stream(
|
||||
self, file_id: Any, session: Optional[ClientSession] = None
|
||||
) -> GridOut:
|
||||
"""Opens a Stream from which the application can read the contents of
|
||||
the stored file specified by file_id.
|
||||
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
# get _id of file to read.
|
||||
file_id = fs.upload_from_stream("test_file", "data I want to store!")
|
||||
grid_out = fs.open_download_stream(file_id)
|
||||
contents = grid_out.read()
|
||||
|
||||
Returns an instance of :class:`~gridfs.grid_file.GridOut`.
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
|
||||
|
||||
:param file_id: The _id of the file to be downloaded.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
gout = GridOut(self._collection, file_id, session=session)
|
||||
|
||||
# Raise NoFile now, instead of on first attribute access.
|
||||
gout._ensure_file()
|
||||
return gout
|
||||
|
||||
@_csot.apply
|
||||
def download_to_stream(
|
||||
self, file_id: Any, destination: Any, session: Optional[ClientSession] = None
|
||||
) -> None:
|
||||
"""Downloads the contents of the stored file specified by file_id and
|
||||
writes the contents to `destination`.
|
||||
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
# Get _id of file to read
|
||||
file_id = fs.upload_from_stream("test_file", "data I want to store!")
|
||||
# Get file to write to
|
||||
file = open('myfile','wb+')
|
||||
fs.download_to_stream(file_id, file)
|
||||
file.seek(0)
|
||||
contents = file.read()
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
|
||||
|
||||
:param file_id: The _id of the file to be downloaded.
|
||||
:param destination: a file-like object implementing :meth:`write`.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
with self.open_download_stream(file_id, session=session) as gout:
|
||||
while True:
|
||||
chunk = gout.readchunk()
|
||||
if not len(chunk):
|
||||
break
|
||||
destination.write(chunk)
|
||||
|
||||
@_csot.apply
|
||||
def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
|
||||
"""Given a file_id, delete this stored file's files collection document
|
||||
and associated chunks from a GridFS bucket.
|
||||
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
# Get _id of file to delete
|
||||
file_id = fs.upload_from_stream("test_file", "data I want to store!")
|
||||
fs.delete(file_id)
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
|
||||
|
||||
:param file_id: The _id of the file to be deleted.
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
_disallow_transactions(session)
|
||||
res = self._files.delete_one({"_id": file_id}, session=session)
|
||||
self._chunks.delete_many({"files_id": file_id}, session=session)
|
||||
if not res.deleted_count:
|
||||
raise NoFile("no file could be deleted because none matched %s" % file_id)
|
||||
|
||||
def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
    """Query the files collection and return a cursor over the matches.

    The returned cursor iterates across files matching arbitrary
    queries on the files collection and can be combined with other
    modifiers for additional control.

    For example::

      for grid_data in fs.find({"filename": "lisa.txt"},
                              no_cursor_timeout=True):
          data = grid_data.read()

    would iterate through all versions of "lisa.txt" stored in GridFS.
    Note that setting no_cursor_timeout to True may be important to
    prevent the cursor from timing out during long multi-file processing
    work.

    As another example, the call::

      most_recent_three = fs.find().sort("uploadDate", -1).limit(3)

    would return a cursor to the three most recently uploaded files
    in GridFS.

    Follows a similar interface to
    :meth:`~pymongo.collection.Collection.find`
    in :class:`~pymongo.collection.Collection`.

    If a :class:`~pymongo.client_session.ClientSession` is passed to
    :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
    are associated with that session.

    All positional and keyword arguments are forwarded unchanged to
    :class:`~gridfs.grid_file.GridOutCursor`.

    :param filter: Search query.
    :param batch_size: The number of documents to return per
        batch.
    :param limit: The maximum number of documents to return.
    :param no_cursor_timeout: The server normally times out idle
        cursors after an inactivity period (10 minutes) to prevent excess
        memory use. Set this option to True to prevent that.
    :param skip: The number of documents to skip before
        returning.
    :param sort: The order by which to sort results. Defaults to
        None.
    """
    files_cursor = GridOutCursor(self._collection, *args, **kwargs)
    return files_cursor
|
||||
|
||||
def open_download_stream_by_name(
    self, filename: str, revision: int = -1, session: Optional[ClientSession] = None
) -> GridOut:
    """Open a stream for reading the contents of `filename`, optionally
    at a specific `revision`.

    For example::

      my_db = MongoClient().test
      fs = GridFSBucket(my_db)
      grid_out = fs.open_download_stream_by_name("test_file")
      contents = grid_out.read()

    Returns an instance of :class:`~gridfs.grid_file.GridOut`.

    Raises :exc:`~gridfs.errors.NoFile` if no such version of
    that file exists.

    `filename` is checked with ``validate_string`` before any query is
    issued, so a non-string filename is rejected there.

    :param filename: The name of the file to read from.
    :param revision: Which revision (documents with the same
        filename and different uploadDate) of the file to retrieve.
        Defaults to -1 (the most recent revision).
    :param session: a
        :class:`~pymongo.client_session.ClientSession`

    :Note: Revision numbers are defined as follows:

      - 0 = the original stored file
      - 1 = the first revision
      - 2 = the second revision
      - etc...
      - -2 = the second most recent revision
      - -1 = the most recent revision

    .. versionchanged:: 3.6
       Added ``session`` parameter.
    """
    validate_string("filename", filename)
    query = {"filename": filename}
    _disallow_transactions(session)
    cursor = self._files.find(query, session=session)

    # Non-negative revisions count forward from the oldest upload;
    # negative revisions count backward from the newest one.
    if revision < 0:
        to_skip, direction = abs(revision) - 1, DESCENDING
    else:
        to_skip, direction = revision, ASCENDING
    cursor.limit(-1).skip(to_skip).sort("uploadDate", direction)

    try:
        grid_file = next(cursor)
    except StopIteration:
        raise NoFile("no version %d for filename %r" % (revision, filename)) from None
    return GridOut(self._collection, file_document=grid_file, session=session)
|
||||
|
||||
@_csot.apply
def download_to_stream_by_name(
    self,
    filename: str,
    destination: Any,
    revision: int = -1,
    session: Optional[ClientSession] = None,
) -> None:
    """Copy the contents of `filename` (with optional `revision`) into
    `destination`, one chunk at a time.

    For example::

      my_db = MongoClient().test
      fs = GridFSBucket(my_db)
      # Get file to write to
      file = open('myfile','wb')
      fs.download_to_stream_by_name("test_file", file)

    The file is located through :meth:`open_download_stream_by_name`,
    so the same lookup errors apply; in particular
    :exc:`~gridfs.errors.NoFile` is raised if no such version of that
    file exists.

    :param filename: The name of the file to read from.
    :param destination: A file-like object that implements :meth:`write`.
    :param revision: Which revision (documents with the same
        filename and different uploadDate) of the file to retrieve.
        Defaults to -1 (the most recent revision).
    :param session: a
        :class:`~pymongo.client_session.ClientSession`

    :Note: Revision numbers are defined as follows:

      - 0 = the original stored file
      - 1 = the first revision
      - 2 = the second revision
      - etc...
      - -2 = the second most recent revision
      - -1 = the most recent revision

    .. versionchanged:: 3.6
       Added ``session`` parameter.
    """
    with self.open_download_stream_by_name(filename, revision, session=session) as gout:
        # An empty chunk signals that the whole file has been read.
        piece = gout.readchunk()
        while len(piece):
            destination.write(piece)
            piece = gout.readchunk()
|
||||
|
||||
def rename(
    self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None
) -> None:
    """Renames the stored file with the specified file_id.

    For example::

      my_db = MongoClient().test
      fs = GridFSBucket(my_db)
      # Get _id of file to rename
      file_id = fs.upload_from_stream("test_file", "data I want to store!")
      fs.rename(file_id, "new_test_name")

    Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.

    :param file_id: The _id of the file to be renamed.
    :param new_filename: The new name of the file.
    :param session: a
        :class:`~pymongo.client_session.ClientSession`

    .. versionchanged:: 3.6
       Added ``session`` parameter.
    """
    _disallow_transactions(session)
    result = self._files.update_one(
        {"_id": file_id}, {"$set": {"filename": new_filename}}, session=session
    )
    if not result.matched_count:
        # file_id may be any value, not just an integer, so it is formatted
        # with %r: an integer-only conversion would raise TypeError here
        # and mask the intended NoFile error.
        raise NoFile(
            "no files could be renamed %r because none "
            "matched file_id %r" % (new_filename, file_id)
        )
|
||||
|
||||
1899
gridfs/asynchronous/grid_file.py
Normal file
1899
gridfs/asynchronous/grid_file.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,953 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for representing files stored in GridFS."""
|
||||
"""Re-import of synchronous gridfs API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Iterable, Mapping, NoReturn, Optional
|
||||
|
||||
from bson.int64 import Int64
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs.errors import CorruptGridFile, FileExists, NoFile
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.cursor import Cursor
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
CursorNotFound,
|
||||
DuplicateKeyError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers import _check_write_command_response
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
# Local aliases for the ``whence`` values understood by ``seek``.
_SEEK_SET, _SEEK_CUR, _SEEK_END = os.SEEK_SET, os.SEEK_CUR, os.SEEK_END

EMPTY = b""
NEWLN = b"\n"

# Default chunk size, in bytes.
# Slightly under a power of 2, to work well with server's record allocations.
DEFAULT_CHUNK_SIZE = 255 * 1024
# The number of chunked bytes to buffer before calling insert_many.
_UPLOAD_BUFFER_SIZE = MAX_MESSAGE_SIZE
# The number of chunk documents to buffer before calling insert_many.
_UPLOAD_BUFFER_CHUNKS = 100000
# Rough BSON overhead of a chunk document not including the chunk data itself.
# Essentially len(encode({"_id": ObjectId(), "files_id": ObjectId(), "n": 1, "data": ""}))
_CHUNK_OVERHEAD = 60

# Key specifications for the indexes on the chunks and files collections.
_C_INDEX: dict[str, Any] = {"files_id": ASCENDING, "n": ASCENDING}
_F_INDEX: dict[str, Any] = {"filename": ASCENDING, "uploadDate": ASCENDING}
|
||||
|
||||
|
||||
def _grid_in_property(
    field_name: str,
    docstring: str,
    read_only: Optional[bool] = False,
    closed_only: Optional[bool] = False,
) -> Any:
    """Build a property exposing one field of a GridIn file document.

    :param field_name: key looked up in the owning object's ``_file`` dict.
    :param docstring: documentation for the property; when it starts with
        ``"DEPRECATED,"`` every get/set emits a :class:`DeprecationWarning`.
    :param read_only: if true, no setter is attached.
    :param closed_only: if true, no setter is attached and reading raises
        :class:`AttributeError` until the owning object's ``_closed`` is set.
    :return: a :class:`property` object.
    """
    deprecation_msg = (
        f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
        if docstring.startswith("DEPRECATED,")
        else ""
    )
    # Protect against PHP-237
    fallback = 0 if field_name == "length" else None

    def getter(self: Any) -> Any:
        if deprecation_msg:
            warnings.warn(deprecation_msg, stacklevel=2, category=DeprecationWarning)
        if closed_only and not self._closed:
            raise AttributeError("can only get %r on a closed file" % field_name)
        return self._file.get(field_name, fallback)

    def setter(self: Any, value: Any) -> Any:
        if deprecation_msg:
            warnings.warn(deprecation_msg, stacklevel=2, category=DeprecationWarning)
        # Once the file is closed its document already exists on the
        # server, so the change is written through immediately.
        if self._closed:
            self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}})
        self._file[field_name] = value

    if read_only:
        docstring += "\n\nThis attribute is read-only."
    elif closed_only:
        docstring = "{}\n\n{}".format(
            docstring,
            "This attribute is read-only and "
            "can only be read after :meth:`close` "
            "has been called.",
        )

    if read_only or closed_only:
        return property(getter, doc=docstring)
    return property(getter, setter, doc=docstring)
|
||||
|
||||
|
||||
def _grid_out_property(field_name: str, docstring: str) -> Any:
    """Build a read-only property exposing one field of a GridOut file document.

    :param field_name: key looked up in the owning object's ``_file`` mapping.
    :param docstring: documentation for the property; when it starts with
        ``"DEPRECATED,"`` every read emits a :class:`DeprecationWarning`.
    :return: a :class:`property` object with a getter only.
    """
    deprecation_msg = (
        f"GridOut property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
        if docstring.startswith("DEPRECATED,")
        else ""
    )
    # Protect against PHP-237
    fallback = 0 if field_name == "length" else None

    def getter(self: Any) -> Any:
        if deprecation_msg:
            warnings.warn(deprecation_msg, stacklevel=2, category=DeprecationWarning)
        # The owning object loads its file document on demand.
        self._ensure_file()
        return self._file.get(field_name, fallback)

    return property(getter, doc=docstring + "\n\nThis attribute is read-only.")
|
||||
|
||||
|
||||
def _clear_entity_type_registry(entity: Any, **kwargs: Any) -> Any:
    """Return a copy of `entity` whose codec options carry no type registry.

    :param entity: a database/collection-like object exposing
        ``codec_options`` and ``with_options``.
    :param kwargs: extra options forwarded to ``entity.with_options``.
    :return: whatever ``entity.with_options`` returns.
    """
    return entity.with_options(
        codec_options=entity.codec_options.with_options(type_registry=None), **kwargs
    )
|
||||
|
||||
|
||||
def _disallow_transactions(session: Optional[ClientSession]) -> None:
    """Reject sessions that are currently inside a transaction.

    :param session: the session about to be used, or ``None``.
    :raises InvalidOperation: if `session` is truthy and reports
        ``in_transaction``.
    """
    if not session or not session.in_transaction:
        return
    raise InvalidOperation("GridFS does not support multi-document transactions")
|
||||
|
||||
|
||||
class GridIn:
    """Class to write data to GridFS."""

    def __init__(
        self, root_collection: Collection, session: Optional[ClientSession] = None, **kwargs: Any
    ) -> None:
        """Write a file to GridFS

        Application developers should generally not need to
        instantiate this class directly - instead see the methods
        provided by :class:`~gridfs.GridFS`.

        Raises :class:`TypeError` if `root_collection` is not an
        instance of :class:`~pymongo.collection.Collection`.

        Any of the file level options specified in the `GridFS Spec
        <http://dochub.mongodb.org/core/gridfsspec>`_ may be passed as
        keyword arguments. Any additional keyword arguments will be
        set as additional fields on the file document. Valid keyword
        arguments include:

          - ``"_id"``: unique ID for this file (default:
            :class:`~bson.objectid.ObjectId`) - this ``"_id"`` must
            not have already been used for another file

          - ``"filename"``: human name for the file

          - ``"contentType"`` or ``"content_type"``: valid mime-type
            for the file

          - ``"chunkSize"`` or ``"chunk_size"``: size of each of the
            chunks, in bytes (default: 255 kb)

          - ``"encoding"``: encoding used for this file. Any :class:`str`
            that is written to the file will be converted to :class:`bytes`.

        :param root_collection: root collection to write to
        :param session: a
            :class:`~pymongo.client_session.ClientSession` to use for all
            commands
        :param kwargs: Any: file level options (see above)

        .. versionchanged:: 4.0
           Removed the `disable_md5` parameter. See
           :ref:`removed-gridfs-checksum` for details.

        .. versionchanged:: 3.7
           Added the `disable_md5` parameter.

        .. versionchanged:: 3.6
           Added ``session`` parameter.

        .. versionchanged:: 3.0
           `root_collection` must use an acknowledged
           :attr:`~pymongo.collection.Collection.write_concern`
        """
        if not isinstance(root_collection, Collection):
            raise TypeError("root_collection must be an instance of Collection")

        if not root_collection.write_concern.acknowledged:
            raise ConfigurationError("root_collection must use acknowledged write_concern")
        _disallow_transactions(session)

        # Handle alternative naming
        if "content_type" in kwargs:
            kwargs["contentType"] = kwargs.pop("content_type")
        if "chunk_size" in kwargs:
            kwargs["chunkSize"] = kwargs.pop("chunk_size")

        coll = _clear_entity_type_registry(root_collection, read_preference=ReadPreference.PRIMARY)

        # Defaults
        kwargs["_id"] = kwargs.get("_id", ObjectId())
        kwargs["chunkSize"] = kwargs.get("chunkSize", DEFAULT_CHUNK_SIZE)
        # Internal state is stored with object.__setattr__ because this
        # class's own __setattr__ routes unknown names into the file document.
        object.__setattr__(self, "_session", session)
        object.__setattr__(self, "_coll", coll)
        object.__setattr__(self, "_chunks", coll.chunks)
        object.__setattr__(self, "_file", kwargs)
        object.__setattr__(self, "_buffer", io.BytesIO())
        object.__setattr__(self, "_position", 0)
        object.__setattr__(self, "_chunk_number", 0)
        object.__setattr__(self, "_closed", False)
        object.__setattr__(self, "_ensured_index", False)
        object.__setattr__(self, "_buffered_docs", [])
        object.__setattr__(self, "_buffered_docs_size", 0)

    def __create_index(self, collection: Collection, index_key: Any, unique: bool) -> None:
        """Create `index_key` on `collection` if the collection has no
        documents and does not already list an index with that key.
        """
        doc = collection.find_one(projection={"_id": 1}, session=self._session)
        if doc is None:
            try:
                index_keys = [
                    index_spec["key"]
                    for index_spec in collection.list_indexes(session=self._session)
                ]
            except OperationFailure:
                # Listing failed; treat the collection as having no indexes.
                index_keys = []
            if index_key not in index_keys:
                collection.create_index(index_key.items(), unique=unique, session=self._session)

    def __ensure_indexes(self) -> None:
        """Run the index checks for the files and chunks collections once
        per instance.
        """
        if not object.__getattribute__(self, "_ensured_index"):
            _disallow_transactions(self._session)
            self.__create_index(self._coll.files, _F_INDEX, False)
            self.__create_index(self._coll.chunks, _C_INDEX, True)
            object.__setattr__(self, "_ensured_index", True)

    def abort(self) -> None:
        """Remove all chunks/files that may have been uploaded and close."""
        self._coll.chunks.delete_many({"files_id": self._file["_id"]}, session=self._session)
        self._coll.files.delete_one({"_id": self._file["_id"]}, session=self._session)
        object.__setattr__(self, "_closed", True)

    @property
    def closed(self) -> bool:
        """Is this file closed?"""
        return self._closed

    _id: Any = _grid_in_property("_id", "The ``'_id'`` value for this file.", read_only=True)
    filename: Optional[str] = _grid_in_property("filename", "Name of this file.")
    name: Optional[str] = _grid_in_property("filename", "Alias for `filename`.")
    content_type: Optional[str] = _grid_in_property(
        "contentType", "DEPRECATED, will be removed in PyMongo 5.0. Mime-type for this file."
    )
    length: int = _grid_in_property("length", "Length (in bytes) of this file.", closed_only=True)
    chunk_size: int = _grid_in_property("chunkSize", "Chunk size for this file.", read_only=True)
    upload_date: datetime.datetime = _grid_in_property(
        "uploadDate", "Date that this file was uploaded.", closed_only=True
    )
    md5: Optional[str] = _grid_in_property(
        "md5",
        "DEPRECATED, will be removed in PyMongo 5.0. MD5 of the contents of this file if an md5 sum was created.",
        closed_only=True,
    )

    _buffer: io.BytesIO
    _closed: bool
    _buffered_docs: list[dict[str, Any]]
    _buffered_docs_size: int

    def __getattr__(self, name: str) -> Any:
        """Fall back to the file document for attributes not found normally."""
        if name in self._file:
            return self._file[name]
        raise AttributeError("GridIn object has no attribute '%s'" % name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an instance/class attribute, or else a field of the file document."""
        # For properties of this instance like _buffer, or descriptors set on
        # the class like filename, use regular __setattr__
        if name in self.__dict__ or name in self.__class__.__dict__:
            object.__setattr__(self, name, value)
        else:
            # All other attributes are part of the document in db.fs.files.
            # Store them to be sent to server on close() or if closed, send
            # them now.
            self._file[name] = value
            if self._closed:
                self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

    def __flush_data(self, data: Any, force: bool = False) -> None:
        """Flush `data` to a chunk.

        The chunk document is buffered and only sent with ``insert_many``
        once `force` is set or a buffer limit is reached.
        """
        self.__ensure_indexes()
        assert len(data) <= self.chunk_size
        if data:
            self._buffered_docs.append(
                {"files_id": self._file["_id"], "n": self._chunk_number, "data": data}
            )
            self._buffered_docs_size += len(data) + _CHUNK_OVERHEAD
        if not self._buffered_docs:
            return
        # Limit to 100,000 chunks or 32MB (+1 chunk) of data.
        if (
            force
            or self._buffered_docs_size >= _UPLOAD_BUFFER_SIZE
            or len(self._buffered_docs) >= _UPLOAD_BUFFER_CHUNKS
        ):
            try:
                self._chunks.insert_many(self._buffered_docs, session=self._session)
            except BulkWriteError as exc:
                # For backwards compatibility, raise an insert_one style exception.
                write_errors = exc.details["writeErrors"]
                for err in write_errors:
                    if err.get("code") in (11000, 11001, 12582):  # Duplicate key errors
                        self._raise_file_exists(self._file["_id"])
                result = {"writeErrors": write_errors}
                wces = exc.details["writeConcernErrors"]
                if wces:
                    result["writeConcernError"] = wces[-1]
                _check_write_command_response(result)
                raise
            self._buffered_docs = []
            self._buffered_docs_size = 0
        self._chunk_number += 1
        self._position += len(data)

    def __flush_buffer(self, force: bool = False) -> None:
        """Flush the buffer contents out to a chunk."""
        self.__flush_data(self._buffer.getvalue(), force=force)
        self._buffer.close()
        self._buffer = io.BytesIO()

    def __flush(self) -> Any:
        """Flush the file to the database."""
        try:
            self.__flush_buffer(force=True)
            # The GridFS spec says length SHOULD be an Int64.
            self._file["length"] = Int64(self._position)
            self._file["uploadDate"] = datetime.datetime.now(tz=datetime.timezone.utc)

            return self._coll.files.insert_one(self._file, session=self._session)
        except DuplicateKeyError:
            self._raise_file_exists(self._id)

    def _raise_file_exists(self, file_id: Any) -> NoReturn:
        """Raise a FileExists exception for the given file_id."""
        raise FileExists("file with _id %r already exists" % file_id)

    def close(self) -> None:
        """Flush the file and close it.

        A closed file cannot be written any more. Calling
        :meth:`close` more than once is allowed.
        """
        if not self._closed:
            self.__flush()
            object.__setattr__(self, "_closed", True)

    def read(self, size: int = -1) -> NoReturn:
        """Always raise :class:`io.UnsupportedOperation`; GridIn is write-only."""
        raise io.UnsupportedOperation("read")

    def readable(self) -> bool:
        """Return ``False``."""
        return False

    def seekable(self) -> bool:
        """Return ``False``."""
        return False

    def write(self, data: Any) -> None:
        """Write data to the file. There is no return value.

        `data` can be either a string of bytes or a file-like object
        (implementing :meth:`read`). If the file has an
        :attr:`encoding` attribute, `data` can also be a
        :class:`str` instance, which will be encoded as
        :attr:`encoding` before being written.

        Due to buffering, the data may not actually be written to the
        database until the :meth:`close` method is called. Raises
        :class:`ValueError` if this file is already closed. Raises
        :class:`TypeError` if `data` is not an instance of
        :class:`bytes`, a file-like object, or an instance of :class:`str`.
        Unicode data is only allowed if the file has an :attr:`encoding`
        attribute.

        :param data: string of bytes or file-like object to be written
            to the file
        """
        if self._closed:
            raise ValueError("cannot write to a closed file")

        try:
            # file-like
            read = data.read
        except AttributeError:
            # string
            if not isinstance(data, (str, bytes)):
                raise TypeError("can only write strings or file-like objects") from None
            if isinstance(data, str):
                try:
                    data = data.encode(self.encoding)
                except AttributeError:
                    raise TypeError(
                        "must specify an encoding for file in order to write str"
                    ) from None
            read = io.BytesIO(data).read

        if self._buffer.tell() > 0:
            # Make sure to flush only when _buffer is complete
            space = self.chunk_size - self._buffer.tell()
            if space:
                try:
                    to_write = read(space)
                except BaseException:
                    self.abort()
                    raise
                self._buffer.write(to_write)
                if len(to_write) < space:
                    return  # EOF or incomplete
            self.__flush_buffer()
        to_write = read(self.chunk_size)
        while to_write and len(to_write) == self.chunk_size:
            self.__flush_data(to_write)
            to_write = read(self.chunk_size)
        # Whatever is left is shorter than a full chunk; keep it buffered.
        self._buffer.write(to_write)

    def writelines(self, sequence: Iterable[Any]) -> None:
        """Write a sequence of strings to the file.

        Does not add separators.
        """
        for line in sequence:
            self.write(line)

    def writable(self) -> bool:
        """Return ``True``.

        This is the spelling used by the standard :mod:`io` interface;
        without it the lookup would fall through to :meth:`__getattr__`
        and raise :class:`AttributeError`.
        """
        return True

    def writeable(self) -> bool:
        """Return ``True``.

        Misspelled predecessor of :meth:`writable`, kept so existing
        callers continue to work.
        """
        return True

    def __enter__(self) -> GridIn:
        """Support for the context manager protocol."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
        """Support for the context manager protocol.

        Close the file if no exceptions occur and allow exceptions to propagate.
        """
        if exc_type is None:
            # No exceptions happened.
            self.close()
        else:
            # Something happened, at minimum mark as closed.
            object.__setattr__(self, "_closed", True)

        # propagate exceptions
        return False
|
||||
|
||||
|
||||
class GridOut(io.IOBase):
|
||||
"""Class to read data out of GridFS."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_collection: Collection,
|
||||
file_id: Optional[int] = None,
|
||||
file_document: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
) -> None:
|
||||
"""Read a file from GridFS
|
||||
|
||||
Application developers should generally not need to
|
||||
instantiate this class directly - instead see the methods
|
||||
provided by :class:`~gridfs.GridFS`.
|
||||
|
||||
Either `file_id` or `file_document` must be specified,
|
||||
`file_document` will be given priority if present. Raises
|
||||
:class:`TypeError` if `root_collection` is not an instance of
|
||||
:class:`~pymongo.collection.Collection`.
|
||||
|
||||
:param root_collection: root collection to read from
|
||||
:param file_id: value of ``"_id"`` for the file to read
|
||||
:param file_document: file document from
|
||||
`root_collection.files`
|
||||
:param session: a
|
||||
:class:`~pymongo.client_session.ClientSession` to use for all
|
||||
commands
|
||||
|
||||
.. versionchanged:: 3.8
|
||||
For better performance and to better follow the GridFS spec,
|
||||
:class:`GridOut` now uses a single cursor to read all the chunks in
|
||||
the file.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
Creating a GridOut does not immediately retrieve the file metadata
|
||||
from the server. Metadata is fetched when first needed.
|
||||
"""
|
||||
if not isinstance(root_collection, Collection):
|
||||
raise TypeError("root_collection must be an instance of Collection")
|
||||
_disallow_transactions(session)
|
||||
|
||||
root_collection = _clear_entity_type_registry(root_collection)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.__chunks = root_collection.chunks
|
||||
self.__files = root_collection.files
|
||||
self.__file_id = file_id
|
||||
self.__buffer = EMPTY
|
||||
# Start position within the current buffered chunk.
|
||||
self.__buffer_pos = 0
|
||||
self.__chunk_iter = None
|
||||
# Position within the total file.
|
||||
self.__position = 0
|
||||
self._file = file_document
|
||||
self._session = session
|
||||
|
||||
_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
|
||||
filename: str = _grid_out_property("filename", "Name of this file.")
|
||||
name: str = _grid_out_property("filename", "Alias for `filename`.")
|
||||
content_type: Optional[str] = _grid_out_property(
|
||||
"contentType", "DEPRECATED, will be removed in PyMongo 5.0. Mime-type for this file."
|
||||
)
|
||||
length: int = _grid_out_property("length", "Length (in bytes) of this file.")
|
||||
chunk_size: int = _grid_out_property("chunkSize", "Chunk size for this file.")
|
||||
upload_date: datetime.datetime = _grid_out_property(
|
||||
"uploadDate", "Date that this file was first uploaded."
|
||||
)
|
||||
aliases: Optional[list[str]] = _grid_out_property(
|
||||
"aliases", "DEPRECATED, will be removed in PyMongo 5.0. List of aliases for this file."
|
||||
)
|
||||
metadata: Optional[Mapping[str, Any]] = _grid_out_property(
|
||||
"metadata", "Metadata attached to this file."
|
||||
)
|
||||
md5: Optional[str] = _grid_out_property(
|
||||
"md5",
|
||||
"DEPRECATED, will be removed in PyMongo 5.0. MD5 of the contents of this file if an md5 sum was created.",
|
||||
)
|
||||
|
||||
_file: Any
|
||||
__chunk_iter: Any
|
||||
|
||||
def _ensure_file(self) -> None:
|
||||
if not self._file:
|
||||
_disallow_transactions(self._session)
|
||||
self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session)
|
||||
if not self._file:
|
||||
raise NoFile(
|
||||
f"no file in gridfs collection {self.__files!r} with _id {self.__file_id!r}"
|
||||
)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
self._ensure_file()
|
||||
if name in self._file:
|
||||
return self._file[name]
|
||||
raise AttributeError("GridOut object has no attribute '%s'" % name)
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def readchunk(self) -> bytes:
|
||||
"""Reads a chunk at a time. If the current position is within a
|
||||
chunk the remainder of the chunk is returned.
|
||||
"""
|
||||
received = len(self.__buffer) - self.__buffer_pos
|
||||
chunk_data = EMPTY
|
||||
chunk_size = int(self.chunk_size)
|
||||
|
||||
if received > 0:
|
||||
chunk_data = self.__buffer[self.__buffer_pos :]
|
||||
elif self.__position < int(self.length):
|
||||
chunk_number = int((received + self.__position) / chunk_size)
|
||||
if self.__chunk_iter is None:
|
||||
self.__chunk_iter = _GridOutChunkIterator(
|
||||
self, self.__chunks, self._session, chunk_number
|
||||
)
|
||||
|
||||
chunk = self.__chunk_iter.next()
|
||||
chunk_data = chunk["data"][self.__position % chunk_size :]
|
||||
|
||||
if not chunk_data:
|
||||
raise CorruptGridFile("truncated chunk")
|
||||
|
||||
self.__position += len(chunk_data)
|
||||
self.__buffer = EMPTY
|
||||
self.__buffer_pos = 0
|
||||
return chunk_data
|
||||
|
||||
def _read_size_or_line(self, size: int = -1, line: bool = False) -> bytes:
|
||||
"""Internal read() and readline() helper."""
|
||||
self._ensure_file()
|
||||
remainder = int(self.length) - self.__position
|
||||
if size < 0 or size > remainder:
|
||||
size = remainder
|
||||
|
||||
if size == 0:
|
||||
return EMPTY
|
||||
|
||||
received = 0
|
||||
data = []
|
||||
while received < size:
|
||||
needed = size - received
|
||||
if self.__buffer:
|
||||
# Optimization: Read the buffer with zero byte copies.
|
||||
buf = self.__buffer
|
||||
chunk_start = self.__buffer_pos
|
||||
chunk_data = memoryview(buf)[self.__buffer_pos :]
|
||||
self.__buffer = EMPTY
|
||||
self.__buffer_pos = 0
|
||||
self.__position += len(chunk_data)
|
||||
else:
|
||||
buf = self.readchunk()
|
||||
chunk_start = 0
|
||||
chunk_data = memoryview(buf)
|
||||
if line:
|
||||
pos = buf.find(NEWLN, chunk_start, chunk_start + needed) - chunk_start
|
||||
if pos >= 0:
|
||||
# Decrease size to exit the loop.
|
||||
size = received + pos + 1
|
||||
needed = pos + 1
|
||||
if len(chunk_data) > needed:
|
||||
data.append(chunk_data[:needed])
|
||||
# Optimization: Save the buffer with zero byte copies.
|
||||
self.__buffer = buf
|
||||
self.__buffer_pos = chunk_start + needed
|
||||
self.__position -= len(self.__buffer) - self.__buffer_pos
|
||||
else:
|
||||
data.append(chunk_data)
|
||||
received += len(chunk_data)
|
||||
|
||||
# Detect extra chunks after reading the entire file.
|
||||
if size == remainder and self.__chunk_iter:
|
||||
try:
|
||||
self.__chunk_iter.next()
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
return b"".join(data)
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
"""Read at most `size` bytes from the file (less if there
|
||||
isn't enough data).
|
||||
|
||||
The bytes are returned as an instance of :class:`bytes`
|
||||
If `size` is negative or omitted all data is read.
|
||||
|
||||
:param size: the number of bytes to read
|
||||
|
||||
.. versionchanged:: 3.8
|
||||
This method now only checks for extra chunks after reading the
|
||||
entire file. Previously, this method would check for extra chunks
|
||||
on every call.
|
||||
"""
|
||||
return self._read_size_or_line(size=size)
|
||||
|
||||
def readline(self, size: int = -1) -> bytes:  # type: ignore[override]
    """Return a single line, or at most `size` bytes, from the file.

    :param size: the maximum number of bytes to read
    """
    line_bytes = self._read_size_or_line(size=size, line=True)
    return line_bytes
|
||||
|
||||
def tell(self) -> int:
    """Report the position this file is currently at."""
    current = self.__position
    return current
|
||||
|
||||
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
    """Move this file's current position and return the new position.

    :param pos: the position (or offset if using relative
       positioning) to seek to
    :param whence: where to seek
       from. :attr:`os.SEEK_SET` (``0``) for absolute file
       positioning, :attr:`os.SEEK_CUR` (``1``) to seek relative
       to the current position, :attr:`os.SEEK_END` (``2``) to
       seek relative to the file's end.

    :raises OSError: with errno 22 when `whence` is not one of the
       three supported values, or when the resulting position is
       negative.

    .. versionchanged:: 4.1
       The method now returns the new position in the file, to
       conform to the behavior of :meth:`io.IOBase.seek`.
    """
    if whence == _SEEK_SET:
        target = pos
    elif whence == _SEEK_CUR:
        target = self.__position + pos
    elif whence == _SEEK_END:
        target = int(self.length) + pos
    else:
        raise OSError(22, "Invalid value for `whence`")

    if target < 0:
        raise OSError(22, "Invalid value for `pos` - must be positive")

    # Only reset state when the position really moves; otherwise keep
    # using the same buffer and chunk iterator.
    if target != self.__position:
        self.__position = target
        self.__buffer = EMPTY
        self.__buffer_pos = 0
        if self.__chunk_iter:
            self.__chunk_iter.close()
            self.__chunk_iter = None
    return target
|
||||
|
||||
def seekable(self) -> bool:
    """Always report ``True``."""
    return True
|
||||
|
||||
def __iter__(self) -> GridOut:
    """Return this object as the iterator over the file's data.

    Iteration yields lines (delimited by ``b'\\n'``) of
    :class:`bytes`. This can be useful when serving files
    using a webserver that handles such an iterator efficiently.

    .. versionchanged:: 3.8
       :class:`CorruptGridFile` is now raised for any truncated,
       missing, or extra chunk in a file. Previously it was only
       raised on a missing chunk.

    .. versionchanged:: 4.0
       Iteration is now over *lines* in the file, instead
       of chunks, to conform to the base class :py:class:`io.IOBase`.
       Use :meth:`GridOut.readchunk` to read chunk by chunk instead
       of line by line.
    """
    return self
|
||||
|
||||
def close(self) -> None:
    """Close the chunk iterator, if one exists, then close the base class.

    Makes GridOut more generically file-like.
    """
    active_iter = self.__chunk_iter
    if active_iter:
        active_iter.close()
        self.__chunk_iter = None
    super().close()
|
||||
|
||||
def write(self, value: Any) -> NoReturn:
    """Always raise :class:`io.UnsupportedOperation`."""
    raise io.UnsupportedOperation("write")
|
||||
|
||||
def writelines(self, lines: Any) -> NoReturn:
    """Always raise :class:`io.UnsupportedOperation`."""
    raise io.UnsupportedOperation("writelines")
|
||||
|
||||
def writable(self) -> bool:
    """Always report ``False``."""
    return False
|
||||
|
||||
def __enter__(self) -> GridOut:
    """Enter the context manager, returning this :class:`GridOut`.

    Together with ``__exit__`` this lets :class:`GridOut` files be used
    with the context manager protocol.
    """
    return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
    """Leave the context manager, closing this :class:`GridOut`.

    Returns ``False``, so an exception raised inside the ``with`` block
    is not suppressed.
    """
    self.close()
    suppress_exception = False
    return suppress_exception
|
||||
|
||||
def fileno(self) -> NoReturn:
    """Always raise :class:`io.UnsupportedOperation`."""
    raise io.UnsupportedOperation("fileno")
|
||||
|
||||
def flush(self) -> None:
    """Do nothing: GridOut is read-only, so there is nothing to flush."""
    return None
|
||||
|
||||
def isatty(self) -> bool:
    """Always report ``False``."""
    return False
|
||||
|
||||
def truncate(self, size: Optional[int] = None) -> NoReturn:
    """Always raise :class:`io.UnsupportedOperation`.

    See https://docs.python.org/3/library/io.html#io.IOBase.writable
    for why truncate has to raise.
    """
    raise io.UnsupportedOperation("truncate")
|
||||
|
||||
# IOBase.__del__ must be overridden: otherwise it triggers __getattr__ on
# __IOBase_closed, which calls _ensure_file and potentially performs I/O.
# I/O is not allowed in __del__ because it can lead to a deadlock.
def __del__(self) -> None:
    return None
|
||||
|
||||
|
||||
class _GridOutChunkIterator:
    """Iterates over a file's chunks using a single cursor.

    Raises CorruptGridFile when encountering any truncated, missing, or extra
    chunk in a file.
    """

    def __init__(
        self,
        grid_out: GridOut,
        chunks: Collection,
        session: Optional[ClientSession],
        next_chunk: Any,
    ) -> None:
        """Prepare iteration over the chunks of `grid_out`.

        :param grid_out: the file being read; its ``_id``, ``chunk_size``
            and ``length`` are captured here.
        :param chunks: the collection the chunk documents are queried from.
        :param session: the session passed to the chunk query, or ``None``.
        :param next_chunk: the chunk number ``n`` expected next.
        """
        self._id = grid_out._id
        self._chunk_size = int(grid_out.chunk_size)
        self._length = int(grid_out.length)
        self._chunks = chunks
        self._session = session
        self._next_chunk = next_chunk
        # Exact integer ceiling division. Going through float would lose
        # precision once the length no longer fits in 53 bits.
        self._num_chunks = -(-self._length // self._chunk_size)
        self._cursor = None

    _cursor: Optional[Cursor]

    def expected_chunk_length(self, chunk_n: int) -> int:
        """Return the data length chunk number `chunk_n` is expected to have.

        Every chunk before the last one is a full ``chunk_size``; otherwise
        the result is whatever remains of the file length.
        """
        if chunk_n < self._num_chunks - 1:
            return self._chunk_size
        return self._length - (self._chunk_size * (self._num_chunks - 1))

    def __iter__(self) -> _GridOutChunkIterator:
        return self

    def _create_cursor(self) -> None:
        """Open a cursor over this file's chunks, sorted by ``n``.

        The query starts at ``self._next_chunk`` when that is positive.
        """
        query = {"files_id": self._id}
        if self._next_chunk > 0:
            query["n"] = {"$gte": self._next_chunk}
        _disallow_transactions(self._session)
        self._cursor = self._chunks.find(query, sort=[("n", 1)], session=self._session)

    def _next_with_retry(self) -> Mapping[str, Any]:
        """Return the next chunk and retry once on CursorNotFound.

        We retry on CursorNotFound to maintain backwards compatibility in
        cases where two calls to read occur more than 10 minutes apart (the
        server's default cursor timeout).
        """
        if self._cursor is None:
            self._create_cursor()
            assert self._cursor is not None
        try:
            return self._cursor.next()
        except CursorNotFound:
            self._cursor.close()
            self._create_cursor()
            return self._cursor.next()

    def next(self) -> Mapping[str, Any]:
        """Return the next chunk document after validating it.

        :raises StopIteration: when the cursor is exhausted and every
            expected chunk has been returned.
        :raises CorruptGridFile: when a chunk is missing, when a non-empty
            extra chunk is found, or when a chunk's data length differs
            from the expected length.
        """
        try:
            chunk = self._next_with_retry()
        except StopIteration:
            if self._next_chunk >= self._num_chunks:
                raise
            raise CorruptGridFile("no chunk #%d" % self._next_chunk) from None

        if chunk["n"] != self._next_chunk:
            self.close()
            raise CorruptGridFile(
                "Missing chunk: expected chunk #%d but found "
                "chunk with n=%d" % (self._next_chunk, chunk["n"])
            )

        if chunk["n"] >= self._num_chunks:
            # According to spec, ignore extra chunks if they are empty.
            if len(chunk["data"]):
                self.close()
                raise CorruptGridFile(
                    "Extra chunk found: expected %d chunks but found "
                    "chunk with n=%d" % (self._num_chunks, chunk["n"])
                )

        expected_length = self.expected_chunk_length(chunk["n"])
        if len(chunk["data"]) != expected_length:
            self.close()
            raise CorruptGridFile(
                "truncated chunk #%d: expected chunk length to be %d but "
                "found chunk with length %d" % (chunk["n"], expected_length, len(chunk["data"]))
            )

        self._next_chunk += 1
        return chunk

    __next__ = next

    def close(self) -> None:
        """Close the underlying cursor, if one is open, and forget it."""
        if self._cursor:
            self._cursor.close()
            self._cursor = None
|
||||
|
||||
|
||||
class GridOutIterator:
    """Iterator yielding the data of each chunk of a file as :class:`bytes`."""

    def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
        # Always start from the first chunk (n == 0).
        self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0)

    def __iter__(self) -> GridOutIterator:
        return self

    def next(self) -> bytes:
        """Return the ``data`` field of the next chunk, converted to bytes."""
        chunk_doc = self.__chunk_iter.next()
        payload = bytes(chunk_doc["data"])
        return payload

    __next__ = next
|
||||
|
||||
|
||||
class GridOutCursor(Cursor):
    """A cursor / iterator for returning GridOut objects as the result
    of an arbitrary query against the GridFS files collection.
    """

    def __init__(
        self,
        collection: Collection,
        filter: Optional[Mapping[str, Any]] = None,
        skip: int = 0,
        limit: int = 0,
        no_cursor_timeout: bool = False,
        sort: Optional[Any] = None,
        batch_size: int = 0,
        session: Optional[ClientSession] = None,
    ) -> None:
        """Create a new cursor, similar to the normal
        :class:`~pymongo.cursor.Cursor`.

        Should not be called directly by application developers - see
        the :class:`~gridfs.GridFS` method :meth:`~gridfs.GridFS.find` instead.

        The query itself runs against ``collection.files``; the remaining
        arguments are forwarded unchanged to the base cursor.

        .. versionadded:: 2.7

        .. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
        """
        _disallow_transactions(session)
        collection = _clear_entity_type_registry(collection)

        # Hold on to the base "fs" collection to create GridOut objects later.
        self.__root_collection = collection

        super().__init__(
            collection.files,
            filter,
            skip=skip,
            limit=limit,
            no_cursor_timeout=no_cursor_timeout,
            sort=sort,
            batch_size=batch_size,
            session=session,
        )

    def next(self) -> GridOut:
        """Get next GridOut object from cursor."""
        _disallow_transactions(self.session)
        next_file = super().next()
        return GridOut(self.__root_collection, file_document=next_file, session=self.session)

    __next__ = next

    def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:
        """Always raise :class:`NotImplementedError`."""
        raise NotImplementedError("Method does not exist for GridOutCursor")

    def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
        """Always raise :class:`NotImplementedError`."""
        raise NotImplementedError("Method does not exist for GridOutCursor")

    def _clone_base(self, session: Optional[ClientSession]) -> GridOutCursor:
        """Creates an empty GridOutCursor for information to be copied into."""
        return GridOutCursor(self.__root_collection, session=session)
|
||||
from gridfs.synchronous.grid_file import * # noqa: F403
|
||||
|
||||
149
gridfs/grid_file_shared.py
Normal file
149
gridfs/grid_file_shared.py
Normal file
@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.errors import InvalidOperation
|
||||
|
||||
# Local aliases for the three ``whence`` values accepted by ``seek``.
_SEEK_SET = os.SEEK_SET
_SEEK_CUR = os.SEEK_CUR
_SEEK_END = os.SEEK_END

EMPTY = b""
NEWLN = b"\n"

# Default chunk size, in bytes.
# Slightly under a power of 2, to work well with server's record allocations.
DEFAULT_CHUNK_SIZE = 255 * 1024
# The number of chunked bytes to buffer before calling insert_many.
_UPLOAD_BUFFER_SIZE = MAX_MESSAGE_SIZE
# The number of chunk documents to buffer before calling insert_many.
_UPLOAD_BUFFER_CHUNKS = 100000
# Rough BSON overhead of a chunk document not including the chunk data itself.
# Essentially len(encode({"_id": ObjectId(), "files_id": ObjectId(), "n": 1, "data": ""}))
_CHUNK_OVERHEAD = 60

# Key specifications: (files_id, n) and (filename, uploadDate), both ascending.
_C_INDEX: dict[str, Any] = {"files_id": ASCENDING, "n": ASCENDING}
_F_INDEX: dict[str, Any] = {"filename": ASCENDING, "uploadDate": ASCENDING}
|
||||
|
||||
|
||||
def _a_grid_in_property(
    field_name: str,
    docstring: str,
    read_only: Optional[bool] = False,
    closed_only: Optional[bool] = False,
) -> Any:
    """Build a getter-only GridIn property backed by ``self._file``.

    :param field_name: the key looked up in ``self._file``.
    :param docstring: the base documentation for the property.
    :param read_only: when true, a read-only note is appended to the doc.
    :param closed_only: when true, reading the property on a file that is
        not closed raises :class:`AttributeError`.
    """

    def getter(self: Any) -> Any:
        if closed_only and not self._closed:
            raise AttributeError("can only get %r on a closed file" % field_name)
        # Protect against PHP-237: a missing "length" reads as 0.
        fallback = 0 if field_name == "length" else None
        return self._file.get(field_name, fallback)

    if read_only:
        doc = docstring + "\n\nThis attribute is read-only."
    elif closed_only:
        note = (
            "This attribute is read-only and "
            "can only be read after :meth:`close` "
            "has been called."
        )
        doc = "{}\n\n{}".format(docstring, note)
    else:
        doc = docstring

    return property(getter, doc=doc)
|
||||
|
||||
|
||||
def _a_grid_out_property(field_name: str, docstring: str) -> Any:
    """Build a read-only GridOut property backed by ``self._file``.

    The getter raises :class:`InvalidOperation` while ``self._file`` is
    falsy.
    """

    def a_getter(self: Any) -> Any:
        if not self._file:
            raise InvalidOperation(
                "You must call GridOut.open() before accessing " "the %s property" % field_name
            )
        # Protect against PHP-237: a missing "length" reads as 0.
        fallback = 0 if field_name == "length" else None
        return self._file.get(field_name, fallback)

    return property(a_getter, doc=docstring + "\n\nThis attribute is read-only.")
|
||||
|
||||
|
||||
def _grid_in_property(
    field_name: str,
    docstring: str,
    read_only: Optional[bool] = False,
    closed_only: Optional[bool] = False,
) -> Any:
    """Build a GridIn property backed by ``self._file``.

    A setter is attached only when neither `read_only` nor `closed_only`
    is set. When `docstring` starts with ``"DEPRECATED,"`` every get and
    set emits a :class:`DeprecationWarning`.

    :param field_name: the key read from and written to ``self._file``.
    :param docstring: the base documentation for the property.
    :param read_only: when true, a read-only note is appended to the doc.
    :param closed_only: when true, reading the property on a file that is
        not closed raises :class:`AttributeError`.
    """
    if docstring.startswith("DEPRECATED,"):
        warn_str = (
            f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
        )
    else:
        warn_str = ""

    def getter(self: Any) -> Any:
        if warn_str:
            warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
        if closed_only and not self._closed:
            raise AttributeError("can only get %r on a closed file" % field_name)
        # Protect against PHP-237: a missing "length" reads as 0.
        fallback = 0 if field_name == "length" else None
        return self._file.get(field_name, fallback)

    def setter(self: Any, value: Any) -> Any:
        if warn_str:
            warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
        if self._closed:
            # The file is already closed, so update the stored document too.
            self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}})
        self._file[field_name] = value

    if read_only:
        doc = docstring + "\n\nThis attribute is read-only."
    elif closed_only:
        note = (
            "This attribute is read-only and "
            "can only be read after :meth:`close` "
            "has been called."
        )
        doc = "{}\n\n{}".format(docstring, note)
    else:
        doc = docstring

    if read_only or closed_only:
        return property(getter, doc=doc)
    return property(getter, setter, doc=doc)
|
||||
|
||||
|
||||
def _grid_out_property(field_name: str, docstring: str) -> Any:
    """Build a read-only GridOut property backed by ``self._file``.

    The getter calls ``self.open()`` before reading. When `docstring`
    starts with ``"DEPRECATED,"`` each read emits a
    :class:`DeprecationWarning`.
    """
    if docstring.startswith("DEPRECATED,"):
        warn_str = (
            f"GridOut property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
        )
    else:
        warn_str = ""

    def getter(self: Any) -> Any:
        if warn_str:
            warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
        self.open()

        # Protect against PHP-237: a missing "length" reads as 0.
        fallback = 0 if field_name == "length" else None
        return self._file.get(field_name, fallback)

    return property(getter, doc=docstring + "\n\nThis attribute is read-only.")
|
||||
|
||||
|
||||
def _clear_entity_type_registry(entity: Any, **kwargs: Any) -> Any:
    """Return a copy of `entity` whose codec options have no type registry.

    `entity` is a database/collection object; any extra keyword
    arguments are forwarded to its ``with_options`` call.
    """
    registry_free_options = entity.codec_options.with_options(type_registry=None)
    stripped = entity.with_options(codec_options=registry_free_options, **kwargs)
    return stripped
|
||||
1887
gridfs/synchronous/grid_file.py
Normal file
1887
gridfs/synchronous/grid_file.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -6,3 +6,10 @@ exclude = (?x)(
|
||||
^test/mypy_fails/*.*$
|
||||
| ^test/conftest.py$
|
||||
)
|
||||
|
||||
[mypy-pymongo.synchronous.*,gridfs.synchronous.*,test.synchronous.*]
|
||||
warn_unused_ignores = false
|
||||
disable_error_code = unused-coroutine
|
||||
|
||||
[mypy-pymongo.asynchronous.*,test.asynchronous.*]
|
||||
warn_unused_ignores = false
|
||||
|
||||
@ -33,6 +33,7 @@ __all__ = [
|
||||
"MIN_SUPPORTED_WIRE_VERSION",
|
||||
"CursorType",
|
||||
"MongoClient",
|
||||
"AsyncMongoClient",
|
||||
"DeleteMany",
|
||||
"DeleteOne",
|
||||
"IndexModel",
|
||||
@ -87,11 +88,12 @@ TEXT = "text"
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo._version import __version__, get_version_string, version_tuple
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.operations import (
|
||||
from pymongo.synchronous.collection import ReturnDocument
|
||||
from pymongo.synchronous.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
IndexModel,
|
||||
@ -100,7 +102,7 @@ from pymongo.operations import (
|
||||
UpdateMany,
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
version = __version__
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import AbstractContextManager
|
||||
@ -96,16 +97,27 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def apply(func: F) -> F:
|
||||
"""Apply the client's timeoutMS to this operation."""
|
||||
"""Apply the client's timeoutMS to this operation. Can wrap both asynchronous and synchronous methods"""
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if get_timeout() is None:
|
||||
timeout = self._timeout
|
||||
if timeout is not None:
|
||||
with _TimeoutContext(timeout):
|
||||
return func(self, *args, **kwargs)
|
||||
return func(self, *args, **kwargs)
|
||||
@functools.wraps(func)
|
||||
async def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if get_timeout() is None:
|
||||
timeout = self._timeout
|
||||
if timeout is not None:
|
||||
with _TimeoutContext(timeout):
|
||||
return await func(self, *args, **kwargs)
|
||||
return await func(self, *args, **kwargs)
|
||||
else:
|
||||
|
||||
@functools.wraps(func)
|
||||
def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if get_timeout() is None:
|
||||
timeout = self._timeout
|
||||
if timeout is not None:
|
||||
with _TimeoutContext(timeout):
|
||||
return func(self, *args, **kwargs)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return cast(F, csot_wrapper)
|
||||
|
||||
|
||||
0
pymongo/asynchronous/__init__.py
Normal file
0
pymongo/asynchronous/__init__.py
Normal file
257
pymongo/asynchronous/aggregation.py
Normal file
257
pymongo/asynchronous/aggregation.py
Normal file
@ -0,0 +1,257 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Perform aggregation operations on a collection or database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping, MutableMapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.asynchronous.typings import _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class _AggregationCommand:
    """The internal abstract base class for aggregation cursors.

    Should not be called directly by application developers. Use
    :meth:`pymongo.asynchronous.collection.AsyncCollection.aggregate`, or
    :meth:`pymongo.asynchronous.database.AsyncDatabase.aggregate` instead.
    """

    def __init__(
        self,
        target: Union[AsyncDatabase, AsyncCollection],
        cursor_class: type[AsyncCommandCursor],
        pipeline: _Pipeline,
        options: MutableMapping[str, Any],
        explicit_session: bool,
        let: Optional[Mapping[str, Any]] = None,
        user_fields: Optional[MutableMapping[str, Any]] = None,
        result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
        comment: Any = None,
    ) -> None:
        """Validate and store everything needed to run the aggregate command.

        Note that `options` is modified in place: ``let`` and ``comment``
        are added when given, and ``batchSize``, ``collation`` and
        ``maxAwaitTimeMS`` are popped out of it.

        :raises ConfigurationError: if ``"explain"`` is present in `options`.
        """
        if "explain" in options:
            raise ConfigurationError(
                "The explain option is not supported. Use AsyncDatabase.command instead."
            )

        self._target = target

        pipeline = common.validate_list("pipeline", pipeline)
        self._pipeline = pipeline
        # The pipeline writes when its final stage is $out or $merge.
        self._performs_write = False
        if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
            self._performs_write = True

        common.validate_is_mapping("options", options)
        if let is not None:
            common.validate_is_mapping("let", let)
            options["let"] = let
        if comment is not None:
            options["comment"] = comment

        self._options = options

        # This is the batchSize that will be used for setting the initial
        # batchSize for the cursor, as well as the subsequent getMores.
        self._batch_size = common.validate_non_negative_integer_or_none(
            "batchSize", self._options.pop("batchSize", None)
        )

        # If the cursor option is already specified, avoid overriding it.
        self._options.setdefault("cursor", {})
        # If the pipeline performs a write, we ignore the initial batchSize
        # since the server doesn't return results in this case.
        if self._batch_size is not None and not self._performs_write:
            self._options["cursor"]["batchSize"] = self._batch_size

        self._cursor_class = cursor_class
        self._explicit_session = explicit_session
        self._user_fields = user_fields
        self._result_processor = result_processor

        self._collation = validate_collation_or_none(options.pop("collation", None))

        self._max_await_time_ms = options.pop("maxAwaitTimeMS", None)
        self._write_preference: Optional[_AggWritePref] = None

    @property
    def _aggregation_target(self) -> Union[str, int]:
        """The argument to pass to the aggregate command."""
        raise NotImplementedError

    @property
    def _cursor_namespace(self) -> str:
        """The namespace in which the aggregate command is run."""
        raise NotImplementedError

    def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
        """The AsyncCollection used for the aggregate command cursor."""
        raise NotImplementedError

    @property
    def _database(self) -> AsyncDatabase:
        """The database against which the aggregation command is run."""
        raise NotImplementedError

    def get_read_preference(
        self, session: Optional[ClientSession]
    ) -> Union[_AggWritePref, _ServerMode]:
        """Return the read preference to use for this aggregation.

        When the pipeline performs a write and the target's preference is
        not PRIMARY, the preference is wrapped in ``_AggWritePref`` and
        cached so later calls return the same object.
        """
        if self._write_preference:
            return self._write_preference
        pref = self._target._read_preference_for(session)
        if self._performs_write and pref != ReadPreference.PRIMARY:
            self._write_preference = pref = _AggWritePref(pref)  # type: ignore[assignment]
        return pref

    async def get_cursor(
        self,
        session: Optional[ClientSession],
        server: Server,
        conn: Connection,
        read_preference: _ServerMode,
    ) -> AsyncCommandCursor[_DocumentType]:
        """Run the aggregate command on `conn` and wrap the reply in a cursor.

        `server` is accepted but not used by this implementation.
        """
        # Serialize command.
        cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline}
        cmd.update(self._options)

        # Apply this target's read concern if:
        # readConcern has not been specified as a kwarg and either
        # - the pipeline doesn't perform a write ($out/$merge) or
        # - the connection's max wire version is >= 8
        if ("readConcern" not in cmd) and (
            not self._performs_write or (conn.max_wire_version >= 8)
        ):
            read_concern = self._target.read_concern
        else:
            read_concern = None

        # Apply this target's write concern if:
        # writeConcern has not been specified as a kwarg and the pipeline
        # performs a write operation ($out/$merge)
        if "writeConcern" not in cmd and self._performs_write:
            write_concern = self._target._write_concern_for(session)
        else:
            write_concern = None

        # Run command.
        result = await conn.command(
            self._database.name,
            cmd,
            read_preference,
            self._target.codec_options,
            parse_write_concern_error=True,
            read_concern=read_concern,
            write_concern=write_concern,
            collation=self._collation,
            session=session,
            client=self._database.client,
            user_fields=self._user_fields,
        )

        if self._result_processor:
            self._result_processor(result, conn)

        # Extract cursor from result or mock/fake one if necessary.
        if "cursor" in result:
            cursor = result["cursor"]
        else:
            # Unacknowledged $out/$merge write. Fake a cursor.
            cursor = {
                "id": 0,
                "firstBatch": result.get("result", []),
                "ns": self._cursor_namespace,
            }

        # Create and return cursor instance.
        cmd_cursor = self._cursor_class(
            self._cursor_collection(cursor),
            cursor,
            conn.address,
            batch_size=self._batch_size or 0,
            max_await_time_ms=self._max_await_time_ms,
            session=session,
            explicit_session=self._explicit_session,
            comment=self._options.get("comment"),
        )
        await cmd_cursor._maybe_pin_connection(conn)
        return cmd_cursor
|
||||
|
||||
|
||||
class _CollectionAggregationCommand(_AggregationCommand):
    """Aggregation command whose target is an AsyncCollection."""

    _target: AsyncCollection

    @property
    def _aggregation_target(self) -> str:
        """The target collection's name."""
        return self._target.name

    @property
    def _cursor_namespace(self) -> str:
        """The target collection's full name."""
        return self._target.full_name

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
        """The AsyncCollection used for the aggregate command cursor."""
        return self._target

    @property
    def _database(self) -> AsyncDatabase:
        """The database the target collection belongs to."""
        return self._target.database
|
||||
|
||||
|
||||
class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
    """Collection aggregation variant used for raw batches."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # For raw-batches, we set the initial batchSize for the cursor to 0.
        if self._performs_write:
            return
        self._options["cursor"]["batchSize"] = 0
|
||||
|
||||
|
||||
class _DatabaseAggregationCommand(_AggregationCommand):
    """Aggregation command whose target is an AsyncDatabase."""

    _target: AsyncDatabase

    @property
    def _aggregation_target(self) -> int:
        """Always ``1`` for a database-level aggregate."""
        return 1

    @property
    def _cursor_namespace(self) -> str:
        """The ``<db>.$cmd.aggregate`` namespace of the target database."""
        return f"{self._target.name}.$cmd.aggregate"

    @property
    def _database(self) -> AsyncDatabase:
        """The target database itself."""
        return self._target

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
        """The AsyncCollection used for the aggregate command cursor."""
        # AsyncCollection level aggregate may not always return the "ns" field
        # according to our MockupDB tests. Let's handle that case for db level
        # aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
        namespace = cursor.get("ns", self._cursor_namespace)
        _dbname, collname = namespace.split(".", 1)
        return self._database[collname]
|
||||
663
pymongo/asynchronous/auth.py
Normal file
663
pymongo/asynchronous/auth.py
Normal file
@ -0,0 +1,663 @@
|
||||
# Copyright 2013-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
from pymongo.asynchronous.auth_aws import _authenticate_aws
|
||||
from pymongo.asynchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
_OIDCAzureCallback,
|
||||
_OIDCGCPCallback,
|
||||
_OIDCProperties,
|
||||
_OIDCTestCallback,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
try:
|
||||
import winkerberos as kerberos # type:ignore[import]
|
||||
|
||||
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
|
||||
_USE_PRINCIPAL = True
|
||||
except ImportError:
|
||||
try:
|
||||
import kerberos # type:ignore[import]
|
||||
except ImportError:
|
||||
HAVE_KERBEROS = False
|
||||
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
MECHANISMS = frozenset(
    (
        "GSSAPI",
        "MONGODB-CR",
        "MONGODB-OIDC",
        "MONGODB-X509",
        "MONGODB-AWS",
        "PLAIN",
        "SCRAM-SHA-1",
        "SCRAM-SHA-256",
        "DEFAULT",
    )
)
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
class _Cache:
|
||||
__slots__ = ("data",)
|
||||
|
||||
_hash_val = hash("_Cache")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = None
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Two instances must always compare equal.
|
||||
if isinstance(other, _Cache):
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, _Cache):
|
||||
return False
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self._hash_val
|
||||
|
||||
|
||||
MongoCredential = namedtuple(
|
||||
"MongoCredential",
|
||||
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
|
||||
)
|
||||
"""A hashable namedtuple of values used for authentication."""
|
||||
|
||||
|
||||
GSSAPIProperties = namedtuple(
|
||||
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
|
||||
)
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(
    mech: str,
    source: Optional[str],
    user: str,
    passwd: str,
    extra: Mapping[str, Any],
    database: Optional[str],
) -> MongoCredential:
    """Validate the inputs for ``mech`` and assemble a ``MongoCredential``.

    :param mech: Authentication mechanism name.
    :param source: Explicit authentication source, or None.
    :param user: Username; may be None for MONGODB-X509, MONGODB-AWS and
        MONGODB-OIDC.
    :param passwd: Password, or None where the mechanism allows it.
    :param extra: Options mapping; only ``"authmechanismproperties"`` is read.
    :param database: Fallback source for PLAIN and the password mechanisms.
    :return: The mechanism specific credentials tuple.
    :raises ConfigurationError: On a missing username/password or an invalid
        option combination.
    :raises ValueError: When GSSAPI or MONGODB-X509 is given a source other
        than ``$external``.
    """
    if user is None and mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC"):
        raise ConfigurationError(f"{mech} requires a username.")

    if mech == "GSSAPI":
        if source not in (None, "$external"):
            raise ValueError("authentication source must be $external or None for GSSAPI")
        gssapi_opts = extra.get("authmechanismproperties", {})
        gssapi_props = GSSAPIProperties(
            service_name=gssapi_opts.get("SERVICE_NAME", "mongodb"),
            canonicalize_host_name=bool(gssapi_opts.get("CANONICALIZE_HOST_NAME", False)),
            service_realm=gssapi_opts.get("SERVICE_REALM"),
        )
        # Source is always $external.
        return MongoCredential(mech, "$external", user, passwd, gssapi_props, None)

    if mech == "MONGODB-X509":
        if passwd is not None:
            raise ConfigurationError("Passwords are not supported by MONGODB-X509")
        if source not in (None, "$external"):
            raise ValueError("authentication source must be $external or None for MONGODB-X509")
        # Source is always $external, user can be None.
        return MongoCredential(mech, "$external", user, None, None, None)

    if mech == "MONGODB-AWS":
        if user is not None and passwd is None:
            raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
        if source not in (None, "$external"):
            raise ConfigurationError(
                "authentication source must be $external or None for MONGODB-AWS"
            )
        aws_opts = extra.get("authmechanismproperties", {})
        aws_props = _AWSProperties(aws_session_token=aws_opts.get("AWS_SESSION_TOKEN"))
        # user can be None for temporary link-local EC2 credentials.
        return MongoCredential(mech, "$external", user, passwd, aws_props, None)

    if mech == "MONGODB-OIDC":
        oidc_opts = extra.get("authmechanismproperties", {})
        callback = oidc_opts.get("OIDC_CALLBACK")
        human_callback = oidc_opts.get("OIDC_HUMAN_CALLBACK")
        environ = oidc_opts.get("ENVIRONMENT")
        token_resource = oidc_opts.get("TOKEN_RESOURCE", "")
        allowed_hosts = oidc_opts.get(
            "ALLOWED_HOSTS",
            [
                "*.mongodb.net",
                "*.mongodb-dev.net",
                "*.mongodb-qa.net",
                "*.mongodbgov.net",
                "localhost",
                "127.0.0.1",
                "::1",
            ],
        )
        # Shared by the "neither supplied" and "callback plus environment" errors.
        provider_msg = (
            "authentication with MONGODB-OIDC requires providing either a callback or a environment"
        )
        if passwd is not None:
            raise ConfigurationError("password is not supported by MONGODB-OIDC")

        if callback or human_callback:
            if environ is not None:
                raise ConfigurationError(provider_msg)
            if callback and human_callback:
                raise ConfigurationError("cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK")
        elif environ is None:
            raise ConfigurationError(provider_msg)
        elif environ == "test":
            if user is not None:
                raise ConfigurationError(
                    "test environment for MONGODB-OIDC does not support username"
                )
            callback = _OIDCTestCallback()
        elif environ == "azure":
            if not token_resource:
                raise ConfigurationError(
                    "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
                )
            callback = _OIDCAzureCallback(token_resource)
        elif environ == "gcp":
            if not token_resource:
                raise ConfigurationError(
                    "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
                )
            callback = _OIDCGCPCallback(token_resource)
        else:
            raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")

        oidc_props = _OIDCProperties(
            callback=callback,
            human_callback=human_callback,
            environment=environ,
            allowed_hosts=allowed_hosts,
            token_resource=token_resource,
            username=user,
        )
        # passwd is necessarily None here: a non-None password was rejected above.
        return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())

    if mech == "PLAIN":
        return MongoCredential(mech, source or database or "$external", user, passwd, None, None)

    # Remaining mechanisms all need a password and default to "admin".
    source_database = source or database or "admin"
    if passwd is None:
        raise ConfigurationError("A password is required.")
    return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
"""XOR two byte strings together."""
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
|
||||
for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_scram_start(
    credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
    """Build the ``saslStart`` command that opens a SCRAM conversation.

    :return: ``(nonce, first_bare, cmd)`` - the random client nonce, the
        client-first-message-bare bytes, and the command document.
    """
    # "=" and "," are structural in SCRAM messages, so they are escaped in
    # the username.
    escaped_user = (
        credentials.username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
    )
    client_nonce = standard_b64encode(os.urandom(32))
    client_first_bare = b"".join((b"n=", escaped_user, b",r=", client_nonce))

    start_cmd = {
        "saslStart": 1,
        "mechanism": mechanism,
        "payload": Binary(b"n,," + client_first_bare),
        "autoAuthorize": 1,
        "options": {"skipEmptyExchange": True},
    }
    return client_nonce, client_first_bare, start_cmd
|
||||
|
||||
|
||||
async def _authenticate_scram(
    credentials: MongoCredential, conn: Connection, mechanism: str
) -> None:
    """Run the SCRAM conversation for ``mechanism`` over ``conn``.

    Reuses the speculative ``saslStart`` reply when the connection's auth
    context reports that speculation succeeded; otherwise sends ``saslStart``
    itself. The derived client and server keys are stored in
    ``credentials.cache`` and reused while salt and iteration count match.

    :raises OperationFailure: If the server's iteration count, nonce or
        signature is invalid, or the conversation does not finish.
    """
    username = credentials.username
    if mechanism == "SCRAM-SHA-256":
        digest, digestmod = "sha256", hashlib.sha256
        data = saslprep(credentials.password).encode("utf-8")
    else:
        digest, digestmod = "sha1", hashlib.sha1
        data = _password_digest(username, credentials.password).encode("utf-8")
    source = credentials.source
    cache = credentials.cache

    auth_ctx = conn.auth_ctx
    if auth_ctx and auth_ctx.speculate_succeeded():
        assert isinstance(auth_ctx, _ScramContext)
        assert auth_ctx.scram_data is not None
        nonce, first_bare = auth_ctx.scram_data
        res = auth_ctx.speculative_authenticate
    else:
        nonce, first_bare, start_cmd = _authenticate_scram_start(credentials, mechanism)
        res = await conn.command(source, start_cmd)

    assert res is not None
    server_first = res["payload"]
    server_fields = _parse_scram_response(server_first)
    iterations = int(server_fields[b"i"])
    if iterations < 4096:
        raise OperationFailure("Server returned an invalid iteration count.")
    salt = server_fields[b"s"]
    rnonce = server_fields[b"r"]
    if not rnonce.startswith(nonce):
        raise OperationFailure("Server returned an invalid nonce.")

    without_proof = b"c=biws,r=" + rnonce
    if cache.data:
        client_key, server_key, csalt, citerations = cache.data
    else:
        client_key = server_key = csalt = citerations = None

    # Salt and / or iterations could change for a number of different
    # reasons. Either changing invalidates the cache.
    if not client_key or salt != csalt or iterations != citerations:
        salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
        client_key = hmac.HMAC(salted_pass, b"Client Key", digestmod).digest()
        server_key = hmac.HMAC(salted_pass, b"Server Key", digestmod).digest()
        cache.data = (client_key, server_key, salt, iterations)

    stored_key = digestmod(client_key).digest()
    auth_msg = b",".join((first_bare, server_first, without_proof))
    client_sig = hmac.HMAC(stored_key, auth_msg, digestmod).digest()
    client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
    client_final = b",".join((without_proof, client_proof))

    # Signature the server must present to prove it also knows the keys.
    expected_server_sig = standard_b64encode(hmac.HMAC(server_key, auth_msg, digestmod).digest())

    res = await conn.command(
        source,
        {
            "saslContinue": 1,
            "conversationId": res["conversationId"],
            "payload": Binary(client_final),
        },
    )

    server_final = _parse_scram_response(res["payload"])
    if not hmac.compare_digest(server_final[b"v"], expected_server_sig):
        raise OperationFailure("Server returned an invalid signature.")

    # A third empty challenge may be required if the server does not support
    # skipEmptyExchange: SERVER-44857.
    if not res["done"]:
        res = await conn.command(
            source,
            {
                "saslContinue": 1,
                "conversationId": res["conversationId"],
                "payload": Binary(b""),
            },
        )
        if not res["done"]:
            raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
|
||||
def _password_digest(username: str, password: str) -> str:
|
||||
"""Get a password digest to use for authentication."""
|
||||
if not isinstance(password, str):
|
||||
raise TypeError("password must be an instance of str")
|
||||
if len(password) == 0:
|
||||
raise ValueError("password can't be empty")
|
||||
if not isinstance(username, str):
|
||||
raise TypeError("username must be an instance of str")
|
||||
|
||||
md5hash = hashlib.md5() # noqa: S324
|
||||
data = f"{username}:mongo:{password}"
|
||||
md5hash.update(data.encode("utf-8"))
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _auth_key(nonce: str, username: str, password: str) -> str:
    """Return the hex MD5 of ``nonce + username + password_digest``.

    The password digest comes from ``_password_digest``, which also performs
    the argument validation.
    """
    material = f"{nonce}{username}{_password_digest(username, password)}"
    return hashlib.md5(material.encode("utf-8")).hexdigest()  # noqa: S324
|
||||
|
||||
|
||||
def _canonicalize_hostname(hostname: str) -> str:
    """Canonicalize hostname following MIT-krb5 behavior.

    Resolves ``hostname`` with ``AI_CANONNAME``, then reverse-resolves the
    first returned address. If the reverse lookup fails, the forward
    canonical name is returned instead. The result is lower-cased.
    """
    # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
    first_result = socket.getaddrinfo(
        hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
    )[0]
    # getaddrinfo entries are (family, type, proto, canonname, sockaddr).
    canonical_name, address = first_result[3], first_result[4]

    try:
        reverse_name = socket.getnameinfo(address, socket.NI_NAMEREQD)[0]
    except socket.gaierror:
        return canonical_name.lower()
    return reverse_name.lower()
|
||||
|
||||
|
||||
async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using GSSAPI.

    Initialises a Kerberos client context for ``<service_name>@<host>``
    (optionally suffixed with ``@<service_realm>``), drives a SASL
    conversation against ``$external`` and always cleans the context up.

    :raises ConfigurationError: If no kerberos module could be imported.
    :raises OperationFailure: If any Kerberos step or the SASL exchange fails;
        ``kerberos.KrbError`` is converted to this type.
    """
    if not HAVE_KERBEROS:
        raise ConfigurationError(
            'The "kerberos" module must be installed to use GSSAPI authentication.'
        )

    try:
        username = credentials.username
        password = credentials.password
        props = credentials.mechanism_properties
        # Starting here and continuing through the while loop below - establish
        # the security context. See RFC 4752, Section 3.1, first paragraph.
        host = conn.address[0]
        if props.canonicalize_host_name:
            host = _canonicalize_hostname(host)
        service = props.service_name + "@" + host
        if props.service_realm is not None:
            service = service + "@" + props.service_realm

        if password is not None:
            if _USE_PRINCIPAL:
                # Note that, though we use unquote_plus for unquoting URI
                # options, we use quote here. Microsoft's UrlUnescape (used
                # by WinKerberos) doesn't support +.
                principal = ":".join((quote(username), quote(password)))
                result, ctx = kerberos.authGSSClientInit(
                    service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
                )
            else:
                # Split "user@domain" so each part can be passed separately.
                if "@" in username:
                    user, domain = username.split("@", 1)
                else:
                    user, domain = username, None
                result, ctx = kerberos.authGSSClientInit(
                    service,
                    gssflags=kerberos.GSS_C_MUTUAL_FLAG,
                    user=user,
                    domain=domain,
                    password=password,
                )
        else:
            # No password: only the service name and flags are supplied.
            result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)

        if result != kerberos.AUTH_GSS_COMPLETE:
            raise OperationFailure("Kerberos context failed to initialize.")

        try:
            # pykerberos uses a weird mix of exceptions and return values
            # to indicate errors.
            # 0 == continue, 1 == complete, -1 == error
            # Only authGSSClientStep can return 0.
            if kerberos.authGSSClientStep(ctx, "") != 0:
                raise OperationFailure("Unknown kerberos failure in step function.")

            # Start a SASL conversation with mongod/s
            # Note: pykerberos deals with base64 encoded byte strings.
            # Since mongo accepts base64 strings as the payload we don't
            # have to use bson.binary.Binary.
            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslStart": 1,
                "mechanism": "GSSAPI",
                "payload": payload,
                "autoAuthorize": 1,
            }
            response = await conn.command("$external", cmd)

            # Limit how many times we loop to catch protocol / library issues
            for _ in range(10):
                result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
                if result == -1:
                    raise OperationFailure("Unknown kerberos failure in step function.")

                # An absent client response is sent as an empty payload.
                payload = kerberos.authGSSClientResponse(ctx) or ""

                cmd = {
                    "saslContinue": 1,
                    "conversationId": response["conversationId"],
                    "payload": payload,
                }
                response = await conn.command("$external", cmd)

                if result == kerberos.AUTH_GSS_COMPLETE:
                    break
            else:
                # Ten rounds without completion: give up rather than spin.
                raise OperationFailure("Kerberos authentication failed to complete.")

            # Once the security context is established actually authenticate.
            # See RFC 4752, Section 3.1, last two paragraphs.
            if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")

            if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")

            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslContinue": 1,
                "conversationId": response["conversationId"],
                "payload": payload,
            }
            await conn.command("$external", cmd)

        finally:
            # Release the Kerberos context on success and failure alike.
            kerberos.authGSSClientClean(ctx)

    except kerberos.KrbError as exc:
        raise OperationFailure(str(exc)) from None
|
||||
|
||||
|
||||
async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using SASL PLAIN (RFC 4616)

    Sends a single ``saslStart`` whose payload is
    ``NUL + username + NUL + password`` to the credential's source database.
    """
    message = f"\x00{credentials.username}\x00{credentials.password}".encode()
    await conn.command(
        credentials.source,
        {
            "saslStart": 1,
            "mechanism": "PLAIN",
            "payload": Binary(message),
            "autoAuthorize": 1,
        },
    )
|
||||
|
||||
|
||||
async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using MONGODB-X509."""
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
# MONGODB-X509 is done after the speculative auth step.
|
||||
return
|
||||
|
||||
cmd = _X509Context(credentials, conn.address).speculate_command()
|
||||
await conn.command("$external", cmd)
|
||||
|
||||
|
||||
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-CR.

    Two round trips on the credential's source database: ``getnonce``, then
    ``authenticate`` with the key derived from that nonce.
    """
    source, username, password = (
        credentials.source,
        credentials.username,
        credentials.password,
    )
    # Get a nonce
    nonce = (await conn.command(source, {"getnonce": 1}))["nonce"]

    # Actually authenticate
    auth_cmd = {
        "authenticate": 1,
        "user": username,
        "nonce": nonce,
        "key": _auth_key(nonce, username, password),
    }
    await conn.command(source, auth_cmd)
|
||||
|
||||
|
||||
async def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate with whichever SCRAM variant the server supports.

    On wire version 7 or later, SCRAM-SHA-256 is used when it appears in the
    mechanisms negotiated for this user (asking the server through a hello
    command if none were negotiated yet). In every other case SCRAM-SHA-1
    is used.
    """
    chosen = "SCRAM-SHA-1"
    if conn.max_wire_version >= 7:
        mechs = conn.negotiated_mechs
        if not mechs:
            source = credentials.source
            hello_cmd = conn.hello_cmd()
            hello_cmd["saslSupportedMechs"] = source + "." + credentials.username
            reply = await conn.command(source, hello_cmd, publish_events=False)
            mechs = reply.get("saslSupportedMechs", [])
        if "SCRAM-SHA-256" in mechs:
            chosen = "SCRAM-SHA-256"
    return await _authenticate_scram(credentials, conn, chosen)
|
||||
|
||||
|
||||
# Dispatch table from mechanism name to the coroutine function that
# authenticates a connection with it. The two SCRAM entries share
# _authenticate_scram with the mechanism argument pre-bound.
_AUTH_MAP: Mapping[str, Callable[..., Coroutine[Any, Any, None]]] = {
    "GSSAPI": _authenticate_gssapi,
    "MONGODB-CR": _authenticate_mongo_cr,
    "MONGODB-X509": _authenticate_x509,
    "MONGODB-AWS": _authenticate_aws,
    "MONGODB-OIDC": _authenticate_oidc,  # type:ignore[dict-item]
    "PLAIN": _authenticate_plain,
    "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
    "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
    "DEFAULT": _authenticate_default,
}
|
||||
|
||||
|
||||
class _AuthContext:
|
||||
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
|
||||
self.credentials = credentials
|
||||
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
|
||||
self.address = address
|
||||
|
||||
@staticmethod
|
||||
def from_credentials(
|
||||
creds: MongoCredential, address: tuple[str, int]
|
||||
) -> Optional[_AuthContext]:
|
||||
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
|
||||
if spec_cls:
|
||||
return cast(_AuthContext, spec_cls(creds, address))
|
||||
return None
|
||||
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
|
||||
self.speculative_authenticate = hello.speculative_authenticate
|
||||
|
||||
def speculate_succeeded(self) -> bool:
|
||||
return bool(self.speculative_authenticate)
|
||||
|
||||
|
||||
class _ScramContext(_AuthContext):
    """Speculative-auth context for the SCRAM mechanisms."""

    def __init__(
        self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
    ) -> None:
        super().__init__(credentials, address)
        self.mechanism = mechanism
        # (nonce, first_bare) from the speculative saslStart; None until
        # speculate_command() has run.
        self.scram_data: Optional[tuple[bytes, bytes]] = None

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Build the speculative ``saslStart`` and remember its nonce data."""
        nonce, first_bare, start_cmd = _authenticate_scram_start(
            self.credentials, self.mechanism
        )
        # Save for later use.
        self.scram_data = (nonce, first_bare)
        # The 'db' field is included only on the speculative command.
        start_cmd["db"] = self.credentials.source
        return start_cmd
|
||||
|
||||
|
||||
class _X509Context(_AuthContext):
    """Speculative-auth context for MONGODB-X509."""

    def speculate_command(self) -> MutableMapping[str, Any]:
        """Return the ``authenticate`` command, naming the user only if set."""
        command = {"authenticate": 1, "mechanism": "MONGODB-X509"}
        username = self.credentials.username
        if username is not None:
            command["user"] = username
        return command
|
||||
|
||||
|
||||
class _OIDCContext(_AuthContext):
    """Speculative-auth context for MONGODB-OIDC."""

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Return the authenticator's speculative command with ``db`` added.

        Returns None when the authenticator has no speculative command.
        """
        spec_cmd = _get_authenticator(self.credentials, self.address).get_spec_auth_cmd()
        if spec_cmd is not None:
            spec_cmd["db"] = self.credentials.source
        return spec_cmd
|
||||
|
||||
|
||||
# Mechanism name -> factory producing the speculative-auth context, called
# as factory(credentials, address). Mechanisms missing from this table get
# no context from _AuthContext.from_credentials. "DEFAULT" uses the
# SCRAM-SHA-256 context.
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
    "MONGODB-X509": _X509Context,
    "SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
    "SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
    "MONGODB-OIDC": _OIDCContext,
    "DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}
|
||||
|
||||
|
||||
async def authenticate(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
    """Authenticate connection.

    Dispatches on ``credentials.mechanism`` through ``_AUTH_MAP``.
    ``reauthenticate`` is forwarded only to the MONGODB-OIDC implementation.

    :raises KeyError: If the mechanism has no entry in ``_AUTH_MAP``.
    """
    mechanism = credentials.mechanism
    # Looked up first so an unknown mechanism fails before any I/O.
    handler = _AUTH_MAP[mechanism]
    if mechanism == "MONGODB-OIDC":
        await _authenticate_oidc(credentials, conn, reauthenticate)
        return
    await handler(credentials, conn)
|
||||
100
pymongo/asynchronous/auth_aws.py
Normal file
100
pymongo/asynchronous/auth_aws.py
Normal file
@ -0,0 +1,100 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MONGODB-AWS Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Type
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.typings import _ReadableBuffer
|
||||
from pymongo.asynchronous.auth import MongoCredential
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
async def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-AWS.

    Drives the SASL conversation through ``pymongo_auth_aws``: one
    ``saslStart`` followed by at most ten ``saslContinue`` rounds against
    ``$external``.

    :raises ConfigurationError: If ``pymongo-auth-aws`` is not installed or
        the server's wire version is below 9.
    :raises OperationFailure: If the helper library reports an error or the
        server never marks the conversation as done.
    """
    try:
        import pymongo_auth_aws  # type:ignore[import]
    except ImportError as e:
        raise ConfigurationError(
            "MONGODB-AWS authentication requires pymongo-auth-aws: "
            "install with: python -m pip install 'pymongo[aws]'"
        ) from e
    # Delayed import.
    from pymongo_auth_aws.auth import (  # type:ignore[import]
        set_cached_credentials,
        set_use_cached_credentials,
    )

    set_use_cached_credentials(True)

    if conn.max_wire_version < 9:
        raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")

    class AwsSaslContext(pymongo_auth_aws.AwsSaslContext):  # type: ignore
        # Dependency injection:
        def binary_type(self) -> Type[Binary]:
            """Return the bson.binary.Binary type."""
            return Binary

        def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
            """Encode a dictionary to BSON."""
            return bson.encode(doc)

        def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
            """Decode BSON to a dictionary."""
            return bson.decode(data)

    try:
        ctx = AwsSaslContext(
            pymongo_auth_aws.AwsCredential(
                credentials.username,
                credentials.password,
                credentials.mechanism_properties.aws_session_token,
            )
        )
        client_payload = ctx.step(None)
        client_first = {"saslStart": 1, "mechanism": "MONGODB-AWS", "payload": client_payload}
        server_first = await conn.command("$external", client_first)
        res = server_first
        # Limit how many times we loop to catch protocol / library issues
        for _ in range(10):
            client_payload = ctx.step(res["payload"])
            cmd = {
                "saslContinue": 1,
                "conversationId": server_first["conversationId"],
                "payload": client_payload,
            }
            res = await conn.command("$external", cmd)
            if res["done"]:
                # SASL complete.
                break
        else:
            # The server never reported completion. Fail loudly instead of
            # returning as though the connection were authenticated.
            raise OperationFailure("MONGODB-AWS authentication failed to complete.")
    except pymongo_auth_aws.PyMongoAuthAwsError as exc:
        # Clear the cached credentials if we hit a failure in auth.
        set_cached_credentials(None)
        # Convert to OperationFailure and include pymongo-auth-aws version.
        raise OperationFailure(
            f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
        ) from None
    except Exception:
        # Clear the cached credentials if we hit a failure in auth.
        set_cached_credentials(None)
        raise
|
||||
380
pymongo/asynchronous/auth_oidc.py
Normal file
380
pymongo/asynchronous/auth_oidc.py
Normal file
@ -0,0 +1,380 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo._azure_helpers import _get_azure_response
|
||||
from pymongo._csot import remaining
|
||||
from pymongo._gcp_helpers import _get_gcp_response
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.auth import MongoCredential
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
@dataclass
class OIDCIdPInfo:
    """Identity-provider details: a required issuer plus two optional fields."""

    issuer: str
    # Both default to None when not supplied.
    clientId: Optional[str] = None
    requestScopes: Optional[list[str]] = None
|
||||
|
||||
|
||||
@dataclass
class OIDCCallbackContext:
    """Arguments handed to ``OIDCCallback.fetch``."""

    timeout_seconds: float
    username: str
    version: int
    # Optional extras; both default to None.
    refresh_token: Optional[str] = None
    idp_info: Optional[OIDCIdPInfo] = None
|
||||
|
||||
|
||||
@dataclass
class OIDCCallbackResult:
    """Value returned by ``OIDCCallback.fetch``: an access token plus extras."""

    access_token: str
    # Optional extras; both default to None.
    expires_in_seconds: Optional[float] = None
    refresh_token: Optional[str] = None
|
||||
|
||||
|
||||
class OIDCCallback(abc.ABC):
    """A base class for defining OIDC callbacks."""

    @abc.abstractmethod
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the credentials for one authentication attempt.

        :param context: The :class:`OIDCCallbackContext` for this call.
        :return: An :class:`OIDCCallbackResult` carrying the access token.
        """
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCProperties:
|
||||
callback: Optional[OIDCCallback] = field(default=None)
|
||||
human_callback: Optional[OIDCCallback] = field(default=None)
|
||||
environment: Optional[str] = field(default=None)
|
||||
allowed_hosts: list[str] = field(default_factory=list)
|
||||
token_resource: Optional[str] = field(default=None)
|
||||
username: str = ""
|
||||
|
||||
|
||||
"""Mechanism properties for MONGODB-OIDC authentication."""
|
||||
|
||||
# Module-level constants for MONGODB-OIDC; units follow the name suffixes.
# NOTE(review): the code that reads these lies outside this excerpt.
TOKEN_BUFFER_MINUTES = 5
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60  # five minutes, in seconds
CALLBACK_VERSION = 1
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
TIME_BETWEEN_CALLS_SECONDS = 0.1
|
||||
|
||||
|
||||
def _get_authenticator(
|
||||
credentials: MongoCredential, address: tuple[str, int]
|
||||
) -> _OIDCAuthenticator:
|
||||
if credentials.cache.data:
|
||||
return credentials.cache.data
|
||||
|
||||
# Extract values.
|
||||
principal_name = credentials.username
|
||||
properties = credentials.mechanism_properties
|
||||
|
||||
# Validate that the address is allowed.
|
||||
if not properties.environment:
|
||||
found = False
|
||||
allowed_hosts = properties.allowed_hosts
|
||||
for patt in allowed_hosts:
|
||||
if patt == address[0]:
|
||||
found = True
|
||||
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
|
||||
found = True
|
||||
if not found:
|
||||
raise ConfigurationError(
|
||||
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
|
||||
)
|
||||
|
||||
# Get or create the cache data.
|
||||
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
|
||||
return credentials.cache.data
|
||||
|
||||
|
||||
class _OIDCTestCallback(OIDCCallback):
    """Callback that reads the access token from ``$OIDC_TOKEN_FILE``."""

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the stripped contents of the token file as the access token.

        :raises RuntimeError: If ``OIDC_TOKEN_FILE`` is unset or empty.
        """
        path = os.environ.get("OIDC_TOKEN_FILE")
        if not path:
            raise RuntimeError(
                'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
            )
        with open(path) as token_fh:
            token = token_fh.read().strip()
        return OIDCCallbackResult(access_token=token)
|
||||
|
||||
|
||||
class _OIDCAWSCallback(OIDCCallback):
    """Callback that reads the access token from ``$AWS_WEB_IDENTITY_TOKEN_FILE``."""

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the stripped contents of the token file as the access token.

        :raises RuntimeError: If ``AWS_WEB_IDENTITY_TOKEN_FILE`` is unset or empty.
        """
        path = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
        if not path:
            raise RuntimeError(
                'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
            )
        with open(path) as token_fh:
            token = token_fh.read().strip()
        return OIDCCallbackResult(access_token=token)
|
||||
|
||||
|
||||
class _OIDCAzureCallback(OIDCCallback):
    """Callback that obtains its token through ``_get_azure_response``."""

    def __init__(self, token_resource: str) -> None:
        # Stored percent-encoded.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token for the stored resource and wrap the reply."""
        reply = _get_azure_response(
            self.token_resource, context.username, context.timeout_seconds
        )
        token = reply["access_token"]
        lifetime = reply["expires_in"]
        return OIDCCallbackResult(access_token=token, expires_in_seconds=lifetime)
|
||||
|
||||
|
||||
class _OIDCGCPCallback(OIDCCallback):
    """OIDC callback that obtains a token through ``_get_gcp_response``."""

    def __init__(self, token_resource: str) -> None:
        """Store the URL-quoted form of ``token_resource``."""
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token for the stored resource and wrap it in a result.

        :param context: Supplies the ``timeout_seconds`` forwarded to
            ``_get_gcp_response``.
        :return: A result carrying the response's ``access_token``.
        """
        gcp_resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
        token = gcp_resp["access_token"]
        return OIDCCallbackResult(access_token=token)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCAuthenticator:
|
||||
username: str
|
||||
properties: _OIDCProperties
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
access_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
token_gen_id: int = field(default=0)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
last_call_time: float = field(default=0)
|
||||
|
||||
async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle a reauthenticate from the server."""
|
||||
# Invalidate the token for the connection.
|
||||
self._invalidate(conn)
|
||||
# Call the appropriate auth logic for the callback type.
|
||||
if self.properties.callback:
|
||||
return await self._authenticate_machine(conn)
|
||||
return await self._authenticate_human(conn)
|
||||
|
||||
async def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle an initial authenticate request."""
|
||||
# First handle speculative auth.
|
||||
# If it succeeded, we are done.
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
resp = ctx.speculative_authenticate
|
||||
if resp and resp["done"]:
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
return resp
|
||||
|
||||
# If spec auth failed, call the appropriate auth logic for the callback type.
|
||||
# We cannot assume that the token is invalid, because a proxy may have been
|
||||
# involved that stripped the speculative auth information.
|
||||
if self.properties.callback:
|
||||
return await self._authenticate_machine(conn)
|
||||
return await self._authenticate_human(conn)
|
||||
|
||||
def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
|
||||
"""Get the appropriate speculative auth command."""
|
||||
if not self.access_token:
|
||||
return None
|
||||
return self._get_start_command({"jwt": self.access_token})
|
||||
|
||||
async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
|
||||
# If there is a cached access token, try to authenticate with it. If
|
||||
# authentication fails with error code 18, invalidate the access token,
|
||||
# fetch a new access token, and try to authenticate again. If authentication
|
||||
# fails for any other reason, raise the error to the user.
|
||||
if self.access_token:
|
||||
try:
|
||||
return await self._sasl_start_jwt(conn)
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
return await self._authenticate_machine(conn)
|
||||
raise
|
||||
return await self._sasl_start_jwt(conn)
|
||||
|
||||
async def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
# If we have a cached access token, try a JwtStepRequest.
|
||||
# authentication fails with error code 18, invalidate the access token,
|
||||
# and try to authenticate again. If authentication fails for any other
|
||||
# reason, raise the error to the user.
|
||||
if self.access_token:
|
||||
try:
|
||||
return await self._sasl_start_jwt(conn)
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
return await self._authenticate_human(conn)
|
||||
raise
|
||||
|
||||
# If we have a cached refresh token, try a JwtStepRequest with that.
|
||||
# If authentication fails with error code 18, invalidate the access and
|
||||
# refresh tokens, and try to authenticate again. If authentication fails for
|
||||
# any other reason, raise the error to the user.
|
||||
if self.refresh_token:
|
||||
try:
|
||||
return await self._sasl_start_jwt(conn)
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
self.refresh_token = None
|
||||
return await self._authenticate_human(conn)
|
||||
raise
|
||||
|
||||
# Start a new Two-Step SASL conversation.
|
||||
# Run a PrincipalStepRequest to get the IdpInfo.
|
||||
cmd = self._get_start_command(None)
|
||||
start_resp = await self._run_command(conn, cmd)
|
||||
# Attempt to authenticate with a JwtStepRequest.
|
||||
return await self._sasl_continue_jwt(conn, start_resp)
|
||||
|
||||
def _get_access_token(self) -> Optional[str]:
|
||||
properties = self.properties
|
||||
cb: Union[None, OIDCCallback]
|
||||
resp: OIDCCallbackResult
|
||||
|
||||
is_human = properties.human_callback is not None
|
||||
if is_human and self.idp_info is None:
|
||||
return None
|
||||
|
||||
if properties.callback:
|
||||
cb = properties.callback
|
||||
if properties.human_callback:
|
||||
cb = properties.human_callback
|
||||
|
||||
prev_token = self.access_token
|
||||
if prev_token:
|
||||
return prev_token
|
||||
|
||||
if cb is None and not prev_token:
|
||||
return None
|
||||
|
||||
if not prev_token and cb is not None:
|
||||
with self.lock:
|
||||
# See if the token was changed while we were waiting for the
|
||||
# lock.
|
||||
new_token = self.access_token
|
||||
if new_token != prev_token:
|
||||
return new_token
|
||||
|
||||
# Ensure that we are waiting a min time between callback invocations.
|
||||
delta = time.time() - self.last_call_time
|
||||
if delta < TIME_BETWEEN_CALLS_SECONDS:
|
||||
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
|
||||
self.last_call_time = time.time()
|
||||
|
||||
if is_human:
|
||||
timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
|
||||
assert self.idp_info is not None
|
||||
else:
|
||||
timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
|
||||
context = OIDCCallbackContext(
|
||||
timeout_seconds=timeout,
|
||||
version=CALLBACK_VERSION,
|
||||
refresh_token=self.refresh_token,
|
||||
idp_info=self.idp_info,
|
||||
username=self.properties.username,
|
||||
)
|
||||
resp = cb.fetch(context)
|
||||
if not isinstance(resp, OIDCCallbackResult):
|
||||
raise ValueError("Callback result must be of type OIDCCallbackResult")
|
||||
self.refresh_token = resp.refresh_token
|
||||
self.access_token = resp.access_token
|
||||
self.token_gen_id += 1
|
||||
|
||||
return self.access_token
|
||||
|
||||
async def _run_command(
|
||||
self, conn: Connection, cmd: MutableMapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
try:
|
||||
return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
self._invalidate(conn)
|
||||
raise
|
||||
|
||||
def _is_auth_error(self, err: Exception) -> bool:
|
||||
if not isinstance(err, OperationFailure):
|
||||
return False
|
||||
return err.code == _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
def _invalidate(self, conn: Connection) -> None:
|
||||
# Ignore the invalidation if a token gen id is given and is less than our
|
||||
# current token gen id.
|
||||
token_gen_id = conn.oidc_token_gen_id or 0
|
||||
if token_gen_id is not None and token_gen_id < self.token_gen_id:
|
||||
return
|
||||
self.access_token = None
|
||||
|
||||
async def _sasl_continue_jwt(
|
||||
self, conn: Connection, start_resp: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
start_payload: dict = bson.decode(start_resp["payload"])
|
||||
if "issuer" in start_payload:
|
||||
self.idp_info = OIDCIdPInfo(**start_payload)
|
||||
access_token = self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
async def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
|
||||
access_token = self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_start_command({"jwt": access_token})
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
|
||||
if payload is None:
|
||||
principal_name = self.username
|
||||
if principal_name:
|
||||
payload = {"n": principal_name}
|
||||
else:
|
||||
payload = {}
|
||||
bin_payload = Binary(bson.encode(payload))
|
||||
return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}
|
||||
|
||||
def _get_continue_command(
|
||||
self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
|
||||
) -> MutableMapping[str, Any]:
|
||||
bin_payload = Binary(bson.encode(payload))
|
||||
return {
|
||||
"saslContinue": 1,
|
||||
"payload": bin_payload,
|
||||
"conversationId": start_resp["conversationId"],
|
||||
}
|
||||
|
||||
|
||||
async def _authenticate_oidc(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
    """Authenticate using MONGODB-OIDC.

    :param credentials: Credentials used to look up the authenticator.
    :param conn: The connection to authenticate; its address is passed to
        ``_get_authenticator``.
    :param reauthenticate: When true, run the reauthentication path instead
        of the initial authentication path.
    :return: Whatever the selected authenticator method returns.
    """
    authenticator = _get_authenticator(credentials, conn.address)
    handler = authenticator.reauthenticate if reauthenticate else authenticator.authenticate
    return await handler(conn)
|
||||
599
pymongo/asynchronous/bulk.py
Normal file
599
pymongo/asynchronous/bulk.py
Normal file
@ -0,0 +1,599 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The bulk write operations interface.
|
||||
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import MutableMapping
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Iterator,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.client_session import ClientSession, _validate_session_write_concern
|
||||
from pymongo.asynchronous.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo.asynchronous.helpers import _get_wce_doc
|
||||
from pymongo.asynchronous.message import (
|
||||
_DELETE,
|
||||
_INSERT,
|
||||
_UPDATE,
|
||||
_BulkWriteContext,
|
||||
_EncryptedBulkWriteContext,
|
||||
_randint,
|
||||
)
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
_DELETE_ONE: int = 1
|
||||
|
||||
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
|
||||
_BAD_VALUE: int = 2
|
||||
_UNKNOWN_ERROR: int = 8
|
||||
_WRITE_CONCERN_ERROR: int = 64
|
||||
|
||||
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
|
||||
|
||||
|
||||
class _Run:
|
||||
"""Represents a batch of write operations."""
|
||||
|
||||
def __init__(self, op_type: int) -> None:
|
||||
"""Initialize a new Run object."""
|
||||
self.op_type: int = op_type
|
||||
self.index_map: list[int] = []
|
||||
self.ops: list[Any] = []
|
||||
self.idx_offset: int = 0
|
||||
|
||||
def index(self, idx: int) -> int:
|
||||
"""Get the original index of an operation in this run.
|
||||
|
||||
:param idx: The Run index that maps to the original index.
|
||||
"""
|
||||
return self.index_map[idx]
|
||||
|
||||
def add(self, original_index: int, operation: Any) -> None:
|
||||
"""Add an operation to this Run instance.
|
||||
|
||||
:param original_index: The original index of this operation
|
||||
within a larger bulk operation.
|
||||
:param operation: The operation document.
|
||||
"""
|
||||
self.index_map.append(original_index)
|
||||
self.ops.append(operation)
|
||||
|
||||
|
||||
def _merge_command(
    run: _Run,
    full_result: MutableMapping[str, Any],
    offset: int,
    result: Mapping[str, Any],
) -> None:
    """Fold one write command's reply into the accumulated bulk result.

    :param run: The run the command belonged to; used to map reply indexes
        back to original operation indexes.
    :param full_result: The accumulated result, updated in place.
    :param offset: Position within ``run`` of the first operation sent in
        this command.
    :param result: The reply to merge.
    """
    op_type = run.op_type
    n_affected = result.get("n", 0)

    if op_type == _INSERT:
        full_result["nInserted"] += n_affected
    elif op_type == _DELETE:
        full_result["nRemoved"] += n_affected
    elif op_type == _UPDATE:
        upserts = result.get("upserted")
        if not upserts:
            full_result["nMatched"] += n_affected
        else:
            # Rewrite each upsert's index to its position in the original bulk.
            for entry in upserts:
                entry["index"] = run.index(entry["index"] + offset)
            full_result["upserted"].extend(upserts)
            upsert_count = len(upserts)
            full_result["nUpserted"] += upsert_count
            full_result["nMatched"] += n_affected - upsert_count
        full_result["nModified"] += result["nModified"]

    for error in result.get("writeErrors") or ():
        # Work on a copy so the server response stays intact for APM.
        annotated = error.copy()
        position = error["index"] + offset
        annotated["index"] = run.index(position)
        # Attach the operation that failed to the error document.
        annotated["op"] = run.ops[position]
        full_result["writeErrors"].append(annotated)

    wc_error = _get_wce_doc(result)
    if wc_error:
        full_result["writeConcernErrors"].append(wc_error)
|
||||
|
||||
|
||||
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
    """Raise the exception that corresponds to a failed bulk result.

    :param full_result: The accumulated bulk result; its ``writeErrors``
        list is sorted by index in place.
    :raises OperationFailure: If the first write error has code 20 and a
        message starting with "Transaction numbers".
    :raises BulkWriteError: In every other case.
    """
    write_errors = full_result["writeErrors"]
    if write_errors:
        write_errors.sort(key=lambda error: error["index"])
        first = write_errors[0]
        code = first["code"]
        msg = first["errmsg"]
        # retryWrites on MMAPv1 should raise an actionable error.
        unsupported_retry = code == 20 and msg.startswith("Transaction numbers")
        if unsupported_retry:
            errmsg = (
                "This MongoDB deployment does not support "
                "retryable writes. Please add retryWrites=false "
                "to your connection string."
            )
            raise OperationFailure(errmsg, code, full_result)
    raise BulkWriteError(full_result)
|
||||
|
||||
|
||||
class _Bulk:
    """The private guts of the bulk write API."""

    def __init__(
        self,
        collection: AsyncCollection[_DocumentType],
        ordered: bool,
        bypass_document_validation: bool,
        comment: Optional[str] = None,
        let: Optional[Any] = None,
    ) -> None:
        """Set up an empty bulk operation.

        :param collection: Target collection; a copy is used whose codec
            options decode into ``dict`` and replace undecodable unicode.
        :param ordered: Whether operations run in the order provided.
        :param bypass_document_validation: Whether to send
            ``bypassDocumentValidation`` with each command.
        :param comment: Optional comment attached to each command.
        :param let: Optional document attached to update and delete commands.
        """
        bulk_codec_options = collection.codec_options._replace(
            unicode_decode_error_handler="replace", document_class=dict
        )
        self.collection = collection.with_options(codec_options=bulk_codec_options)
        self.let = let
        if let is not None:
            common.validate_is_document_type("let", let)
        self.comment: Optional[str] = comment
        self.ordered = ordered
        self.ops: list[tuple[int, Mapping[str, Any]]] = []
        self.executed = False
        self.bypass_doc_val = bypass_document_validation
        # Feature flags recorded while operations are added.
        self.uses_collation = False
        self.uses_array_filters = False
        self.uses_hint_update = False
        self.uses_hint_delete = False
        # Retry bookkeeping.
        self.is_retryable = True
        self.retrying = False
        self.started_retryable_write = False
        # Extra state so that we know where to pick up on a retry attempt.
        self.current_run = None
        self.next_run = None

    @property
    def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
        """The write-context class: encrypted unless auto encryption is off or bypassed."""
        encrypter = self.collection.database.client._encrypter
        use_encryption = encrypter and not encrypter._bypass_auto_encryption
        return _EncryptedBulkWriteContext if use_encryption else _BulkWriteContext

    def add_insert(self, document: _DocumentOut) -> None:
        """Queue an insert of ``document``, assigning an ``_id`` if it lacks one."""
        validate_is_document_type("document", document)
        # Generate ObjectId client side (raw BSON documents are left untouched).
        needs_id = not isinstance(document, RawBSONDocument) and "_id" not in document
        if needs_id:
            document["_id"] = ObjectId()
        self.ops.append((_INSERT, document))

    def add_update(
        self,
        selector: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        multi: bool = False,
        upsert: bool = False,
        collation: Optional[Mapping[str, Any]] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Build an update statement and queue it."""
        validate_ok_for_update(update)
        statement: dict[str, Any] = {
            "q": selector,
            "u": update,
            "multi": multi,
            "upsert": upsert,
        }
        if collation is not None:
            self.uses_collation = True
            statement["collation"] = collation
        if array_filters is not None:
            self.uses_array_filters = True
            statement["arrayFilters"] = array_filters
        if hint is not None:
            self.uses_hint_update = True
            statement["hint"] = hint
        if multi:
            # A bulk_write containing an update_many is not retryable.
            self.is_retryable = False
        self.ops.append((_UPDATE, statement))

    def add_replace(
        self,
        selector: Mapping[str, Any],
        replacement: Mapping[str, Any],
        upsert: bool = False,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Build a replace statement (a non-multi update) and queue it."""
        validate_ok_for_replace(replacement)
        statement = {"q": selector, "u": replacement, "multi": False, "upsert": upsert}
        if collation is not None:
            self.uses_collation = True
            statement["collation"] = collation
        if hint is not None:
            self.uses_hint_update = True
            statement["hint"] = hint
        self.ops.append((_UPDATE, statement))

    def add_delete(
        self,
        selector: Mapping[str, Any],
        limit: int,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Build a delete statement and queue it."""
        statement = {"q": selector, "limit": limit}
        if collation is not None:
            self.uses_collation = True
            statement["collation"] = collation
        if hint is not None:
            self.uses_hint_delete = True
            statement["hint"] = hint
        if limit == _DELETE_ALL:
            # A bulk_write containing a delete_many is not retryable.
            self.is_retryable = False
        self.ops.append((_DELETE, statement))

    def gen_ordered(self) -> Iterator[Optional[_Run]]:
        """Yield runs of consecutive same-typed operations, in the order
        **provided**.
        """
        current = None
        for position, (kind, op) in enumerate(self.ops):
            # A change of operation type closes the current run.
            if current is not None and current.op_type != kind:
                yield current
                current = None
            if current is None:
                current = _Run(kind)
            current.add(position, op)
        yield current

    def gen_unordered(self) -> Iterator[_Run]:
        """Yield one run per operation type that has operations, in
        arbitrary order.
        """
        # The op type value doubles as the index into this list.
        buckets = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
        for position, (kind, op) in enumerate(self.ops):
            buckets[kind].add(position, op)

        for bucket in buckets:
            if not bucket.ops:
                continue
            yield bucket

    async def _execute_command(
        self,
        generator: Iterator[Any],
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        conn: Connection,
        op_id: int,
        retryable: bool,
        full_result: MutableMapping[str, Any],
        final_write_concern: Optional[WriteConcern] = None,
    ) -> None:
        """Send every remaining run over ``conn``, merging replies into ``full_result``.

        Progress is kept in ``self.current_run`` / ``self.next_run`` and each
        run's ``idx_offset`` so that a later call can pick up where this one
        stopped.

        :param generator: Iterator of runs to execute.
        :param write_concern: Write concern applied to each command.
        :param session: Optional session applied to each command.
        :param conn: Connection the commands are sent on.
        :param op_id: Operation id passed to the bulk write context.
        :param retryable: Whether to start a retryable write on the session.
        :param full_result: Accumulated result, updated in place.
        :param final_write_concern: If given, replaces ``write_concern`` once
            only the last operation of the last run remains.
        """
        db_name = self.collection.database.name
        client = self.collection.database.client
        listeners = client._event_listeners

        if not self.current_run:
            self.current_run = next(generator)
            self.next_run = None
        run = self.current_run

        # Connection.command validates the session, but we use
        # Connection.write_command
        conn.validate_session(client, session)
        last_run = False

        while run:
            if not self.retrying:
                self.next_run = next(generator, None)
                if self.next_run is None:
                    last_run = True

            cmd_name = _COMMANDS[run.op_type]
            bwc = self.bulk_ctx_class(
                db_name,
                cmd_name,
                conn,
                op_id,
                listeners,
                session,
                run.op_type,
                self.collection.codec_options,
            )

            while run.idx_offset < len(run.ops):
                # If this is the last possible operation, use the
                # final write concern.
                if last_run and (len(run.ops) - run.idx_offset) == 1:
                    write_concern = final_write_concern or write_concern

                cmd = {cmd_name: self.collection.name, "ordered": self.ordered}
                if self.comment:
                    cmd["comment"] = self.comment
                _csot.apply_write_concern(cmd, write_concern)
                if self.bypass_doc_val:
                    cmd["bypassDocumentValidation"] = True
                if self.let is not None and run.op_type in (_DELETE, _UPDATE):
                    cmd["let"] = self.let
                if session:
                    # Start a new retryable write unless one was already
                    # started for this command.
                    if retryable and not self.started_retryable_write:
                        session._start_retryable_write()
                        self.started_retryable_write = True
                    await session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
                conn.send_cluster_time(cmd, session, client)
                conn.add_server_api(cmd)
                # CSOT: apply timeout before encoding the command.
                conn.apply_timeout(client, cmd)
                ops = islice(run.ops, run.idx_offset, None)

                # Run as many ops as possible in one command.
                if write_concern.acknowledged:
                    result, to_send = await bwc.execute(cmd, ops, client)

                    # Retryable writeConcernErrors halt the execution of this run.
                    wce = result.get("writeConcernError", {})
                    if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
                        # Synthesize the full bulk result without modifying the
                        # current one because this write operation may be retried.
                        full = copy.deepcopy(full_result)
                        _merge_command(run, full, run.idx_offset, result)
                        _raise_bulk_write_error(full)

                    _merge_command(run, full_result, run.idx_offset, result)

                    # We're no longer in a retry once a command succeeds.
                    self.retrying = False
                    self.started_retryable_write = False

                    if self.ordered and "writeErrors" in result:
                        break
                else:
                    to_send = await bwc.execute_unack(cmd, ops, client)

                run.idx_offset += len(to_send)

            # We're supposed to continue if errors are
            # at the write concern level (e.g. wtimeout)
            if self.ordered and full_result["writeErrors"]:
                break
            # Reset our state
            self.current_run = run = self.next_run

    async def execute_command(
        self,
        generator: Iterator[Any],
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        operation: str,
    ) -> dict[str, Any]:
        """Execute using write commands.

        :return: The accumulated bulk result.
        :raises: Whatever ``_raise_bulk_write_error`` raises when the result
            contains write errors or write concern errors.
        """
        # nModified is only reported for write commands, not legacy ops.
        full_result = {
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nRemoved": 0,
            "upserted": [],
        }
        op_id = _randint()

        async def retryable_bulk(
            session: Optional[ClientSession], conn: Connection, retryable: bool
        ) -> None:
            await self._execute_command(
                generator,
                write_concern,
                session,
                conn,
                op_id,
                retryable,
                full_result,
            )

        client = self.collection.database.client
        await client._retryable_write(
            self.is_retryable,
            retryable_bulk,
            session,
            operation,
            bulk=self,
            operation_id=op_id,
        )

        failed = full_result["writeErrors"] or full_result["writeConcernErrors"]
        if failed:
            _raise_bulk_write_error(full_result)
        return full_result

    async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
        """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
        db_name = self.collection.database.name
        client = self.collection.database.client
        listeners = client._event_listeners
        op_id = _randint()

        if not self.current_run:
            self.current_run = next(generator)
        run = self.current_run

        while run:
            cmd_name = _COMMANDS[run.op_type]
            bwc = self.bulk_ctx_class(
                db_name,
                cmd_name,
                conn,
                op_id,
                listeners,
                None,
                run.op_type,
                self.collection.codec_options,
            )

            while run.idx_offset < len(run.ops):
                cmd = {
                    cmd_name: self.collection.name,
                    "ordered": False,
                    "writeConcern": {"w": 0},
                }
                conn.add_server_api(cmd)
                pending = islice(run.ops, run.idx_offset, None)
                # Run as many ops as possible.
                sent = await bwc.execute_unack(cmd, pending, client)
                run.idx_offset += len(sent)
            run = next(generator, None)
            self.current_run = run

    async def execute_command_no_results(
        self,
        conn: Connection,
        generator: Iterator[Any],
        write_concern: WriteConcern,
    ) -> None:
        """Execute write commands with OP_MSG and w=0 WriteConcern, ordered.

        Any ``OperationFailure`` raised while executing is deliberately
        swallowed; no result is returned.
        """
        full_result = {
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nRemoved": 0,
            "upserted": [],
        }
        # Ordered bulk writes have to be acknowledged so that we stop
        # processing at the first error, even when the application
        # specified unacknowledged writeConcern.
        initial_write_concern = WriteConcern()
        op_id = _randint()
        try:
            await self._execute_command(
                generator,
                initial_write_concern,
                None,
                conn,
                op_id,
                False,
                full_result,
                write_concern,
            )
        except OperationFailure:
            pass

    async def execute_no_results(
        self,
        conn: Connection,
        generator: Iterator[Any],
        write_concern: WriteConcern,
    ) -> None:
        """Execute all operations, returning no results (w=0).

        :raises ConfigurationError: If collation or arrayFilters were used,
            or a hint was used against a server whose wire version is too low.
        :raises OperationFailure: If bypass_document_validation was requested.
        """
        if self.uses_collation:
            raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
        if self.uses_array_filters:
            raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
        # Guard against unsupported unacknowledged writes.
        unack = write_concern and not write_concern.acknowledged
        if unack and self.uses_hint_delete and conn.max_wire_version < 9:
            raise ConfigurationError(
                "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
            )
        if unack and self.uses_hint_update and conn.max_wire_version < 8:
            raise ConfigurationError(
                "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
            )
        # Cannot have both unacknowledged writes and bypass document validation.
        if self.bypass_doc_val:
            raise OperationFailure(
                "Cannot set bypass_document_validation with unacknowledged write concern"
            )

        if not self.ordered:
            return await self.execute_op_msg_no_results(conn, generator)
        return await self.execute_command_no_results(conn, generator, write_concern)

    async def execute(
        self,
        write_concern: WriteConcern,
        session: Optional[ClientSession],
        operation: str,
    ) -> Any:
        """Execute operations.

        :return: The bulk result for acknowledged writes, ``None`` otherwise.
        :raises InvalidOperation: If there are no operations or this bulk has
            already been executed.
        """
        if not self.ops:
            raise InvalidOperation("No operations to execute")
        if self.executed:
            raise InvalidOperation("Bulk operations can only be executed once.")
        self.executed = True
        write_concern = write_concern or self.collection.write_concern
        session = _validate_session_write_concern(session, write_concern)

        generator = self.gen_ordered() if self.ordered else self.gen_unordered()

        client = self.collection.database.client
        if write_concern.acknowledged:
            return await self.execute_command(generator, write_concern, session, operation)

        async with await client._conn_for_writes(session, operation) as connection:
            await self.execute_no_results(connection, generator, write_concern)
        return None
|
||||
499
pymongo/asynchronous/change_stream.py
Normal file
499
pymongo/asynchronous/change_stream.py
Normal file
@ -0,0 +1,499 @@
|
||||
# Copyright 2017 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Watch changes on a collection, a database, or the entire cluster."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.aggregation import (
|
||||
_AggregationCommand,
|
||||
_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand,
|
||||
)
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.operations import _Op
|
||||
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
CursorNotFound,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# The change streams spec considers the following server errors from the
|
||||
# getMore command non-resumable. All other getMore errors are resumable.
|
||||
_RESUMABLE_GETMORE_ERRORS = frozenset(
|
||||
[
|
||||
6, # HostUnreachable
|
||||
7, # HostNotFound
|
||||
89, # NetworkTimeout
|
||||
91, # ShutdownInProgress
|
||||
189, # PrimarySteppedDown
|
||||
262, # ExceededTimeLimit
|
||||
9001, # SocketException
|
||||
10107, # NotWritablePrimary
|
||||
11600, # InterruptedAtShutdown
|
||||
11602, # InterruptedDueToReplStateChange
|
||||
13435, # NotPrimaryNoSecondaryOk
|
||||
13436, # NotPrimaryOrSecondary
|
||||
63, # StaleShardVersion
|
||||
150, # StaleEpoch
|
||||
13388, # StaleConfig
|
||||
234, # RetryChangeStream
|
||||
133, # FailedToSatisfyReadPreference
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
    """Decide whether a change stream may be resumed after ``exc``.

    Connection failures and lost cursors are always resumable. Any other
    ``OperationFailure`` is resumable only when the server's wire version is
    known and either (wire version >= 9) the error carries the
    ``ResumableChangeStreamError`` label, or (wire version < 9) its code is
    listed in ``_RESUMABLE_GETMORE_ERRORS``. Everything else is fatal.
    """
    if isinstance(exc, (ConnectionFailure, CursorNotFound)):
        return True
    if not isinstance(exc, OperationFailure):
        return False

    wire_version = exc._max_wire_version
    if wire_version is None:
        # Without a wire version we cannot tell which rule applies.
        return False
    if wire_version >= 9:
        return exc.has_error_label("ResumableChangeStreamError")
    return exc.code in _RESUMABLE_GETMORE_ERRORS
|
||||
|
||||
|
||||
class ChangeStream(Generic[_DocumentType]):
    """The internal abstract base class for change stream cursors.

    Should not be called directly by application developers. Use
    :meth:`pymongo.asynchronous.collection.AsyncCollection.watch`,
    :meth:`pymongo.asynchronous.database.AsyncDatabase.watch`, or
    :meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.watch` instead.

    Subclasses must provide ``_aggregation_command_class`` and ``_client``;
    the base implementations raise :exc:`NotImplementedError`.

    .. versionadded:: 3.6
    .. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
    """

    def __init__(
        self,
        target: Union[
            AsyncMongoClient[_DocumentType],
            AsyncDatabase[_DocumentType],
            AsyncCollection[_DocumentType],
        ],
        pipeline: Optional[_Pipeline],
        full_document: Optional[str],
        resume_after: Optional[Mapping[str, Any]],
        max_await_time_ms: Optional[int],
        batch_size: Optional[int],
        collation: Optional[_CollationIn],
        start_at_operation_time: Optional[Timestamp],
        session: Optional[ClientSession],
        start_after: Optional[Mapping[str, Any]],
        comment: Optional[Any] = None,
        full_document_before_change: Optional[str] = None,
        show_expanded_events: Optional[bool] = None,
    ) -> None:
        # Validate the arguments and record the stream configuration. No
        # cursor is created here: `_initialize_cursor` must be awaited before
        # the stream is iterated or closed, because `self._cursor` is only
        # assigned there and in `_resume`.
        if pipeline is None:
            pipeline = []
        pipeline = common.validate_list("pipeline", pipeline)
        common.validate_string_or_none("full_document", full_document)
        validate_collation_or_none(collation)
        common.validate_non_negative_integer_or_none("batchSize", batch_size)

        # When the target's type registry defines custom decoders, changes are
        # fetched as RawBSONDocument and decoded in `try_next` with the
        # original codec options.
        self._decode_custom = False
        self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
        if target.codec_options.type_registry._decoder_map:
            self._decode_custom = True
            # Keep the type registry so that we support encoding custom types
            # in the pipeline.
            self._target = target.with_options(  # type: ignore
                codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
            )
        else:
            self._target = target

        # Deep copies keep later caller-side mutation of the pipeline or the
        # resume token from affecting this stream.
        self._pipeline = copy.deepcopy(pipeline)
        self._full_document = full_document
        self._full_document_before_change = full_document_before_change
        self._uses_start_after = start_after is not None
        self._uses_resume_after = resume_after is not None
        # start_after takes precedence over resume_after when both are given.
        self._resume_token = copy.deepcopy(start_after or resume_after)
        self._max_await_time_ms = max_await_time_ms
        self._batch_size = batch_size
        self._collation = collation
        self._start_at_operation_time = start_at_operation_time
        self._session = session
        self._comment = comment
        self._closed = False
        self._timeout = self._target._timeout
        self._show_expanded_events = show_expanded_events

    async def _initialize_cursor(self) -> None:
        """Run the initial aggregation and store the resulting cursor."""
        # Initialize cursor.
        self._cursor = await self._create_cursor()

    @property
    def _aggregation_command_class(self) -> Type[_AggregationCommand]:
        """The aggregation command class to be used."""
        raise NotImplementedError

    @property
    def _client(self) -> AsyncMongoClient:
        """The client against which the aggregation commands for
        this ChangeStream will be run.
        """
        raise NotImplementedError

    def _change_stream_options(self) -> dict[str, Any]:
        """Return the options dict for the $changeStream pipeline stage."""
        options: dict[str, Any] = {}
        if self._full_document is not None:
            options["fullDocument"] = self._full_document

        if self._full_document_before_change is not None:
            options["fullDocumentBeforeChange"] = self._full_document_before_change

        # A cached resume token wins over startAtOperationTime; the two are
        # never sent together.
        resume_token = self.resume_token
        if resume_token is not None:
            if self._uses_start_after:
                options["startAfter"] = resume_token
            else:
                options["resumeAfter"] = resume_token

        elif self._start_at_operation_time is not None:
            options["startAtOperationTime"] = self._start_at_operation_time

        # Only sent when truthy; False and None both omit the option.
        if self._show_expanded_events:
            options["showExpandedEvents"] = self._show_expanded_events

        return options

    def _command_options(self) -> dict[str, Any]:
        """Return the options dict for the aggregation command."""
        options = {}
        if self._max_await_time_ms is not None:
            options["maxAwaitTimeMS"] = self._max_await_time_ms
        if self._batch_size is not None:
            options["batchSize"] = self._batch_size
        return options

    def _aggregation_pipeline(self) -> list[dict[str, Any]]:
        """Return the full aggregation pipeline for this ChangeStream."""
        options = self._change_stream_options()
        # The $changeStream stage always comes first, followed by the
        # user-supplied stages.
        full_pipeline: list = [{"$changeStream": options}]
        full_pipeline.extend(self._pipeline)
        return full_pipeline

    def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
        """Callback that caches the postBatchResumeToken or
        startAtOperationTime from a changeStream aggregate command response
        containing an empty batch of change documents.

        This is implemented as a callback because we need access to the wire
        version in order to determine whether to cache this value.

        Raises :exc:`~pymongo.errors.OperationFailure` if ``operationTime``
        is needed but missing from the response.
        """
        if not result["cursor"]["firstBatch"]:
            if "postBatchResumeToken" in result["cursor"]:
                self._resume_token = result["cursor"]["postBatchResumeToken"]
            elif (
                self._start_at_operation_time is None
                and self._uses_resume_after is False
                and self._uses_start_after is False
                and conn.max_wire_version >= 7
            ):
                self._start_at_operation_time = result.get("operationTime")
                # PYTHON-2181: informative error on missing operationTime.
                if self._start_at_operation_time is None:
                    raise OperationFailure(
                        "Expected field 'operationTime' missing from command "
                        f"response : {result!r}"
                    )

    async def _run_aggregation_cmd(
        self, session: Optional[ClientSession], explicit_session: bool
    ) -> AsyncCommandCursor:
        """Run the full aggregation pipeline for this ChangeStream and return
        the corresponding AsyncCommandCursor.
        """
        cmd = self._aggregation_command_class(
            self._target,
            AsyncCommandCursor,
            self._aggregation_pipeline(),
            self._command_options(),
            explicit_session,
            result_processor=self._process_result,
            comment=self._comment,
        )
        return await self._client._retryable_read(
            cmd.get_cursor,
            self._target._read_preference_for(session),
            session,
            operation=_Op.AGGREGATE,
        )

    async def _create_cursor(self) -> AsyncCommandCursor:
        """Open a new cursor, using the caller's session if one was given."""
        # close=False: the temporary session is not closed when the block
        # exits (presumably the returned cursor keeps using it -- confirm
        # against AsyncMongoClient._tmp_session).
        async with self._client._tmp_session(self._session, close=False) as s:
            return await self._run_aggregation_cmd(
                session=s, explicit_session=self._session is not None
            )

    async def _resume(self) -> None:
        """Reestablish this change stream after a resumable error."""
        # Closing the old cursor is best effort; a PyMongoError here must not
        # prevent the new cursor from being created.
        try:
            await self._cursor.close()
        except PyMongoError:
            pass
        self._cursor = await self._create_cursor()

    async def close(self) -> None:
        """Close this ChangeStream."""
        self._closed = True
        await self._cursor.close()

    def __aiter__(self) -> ChangeStream[_DocumentType]:
        return self

    @property
    def resume_token(self) -> Optional[Mapping[str, Any]]:
        """The cached resume token that will be used to resume after the most
        recently returned change.

        A deep copy is returned, so mutating it does not affect the stream.

        .. versionadded:: 3.9
        """
        return copy.deepcopy(self._resume_token)

    @_csot.apply
    async def next(self) -> _DocumentType:
        """Advance the cursor.

        This method blocks until the next change document is returned or an
        unrecoverable error is raised. This method is used when iterating over
        all changes in the cursor. For example::

            try:
                resume_token = None
                pipeline = [{'$match': {'operationType': 'insert'}}]
                async with db.collection.watch(pipeline) as stream:
                    async for insert_change in stream:
                        print(insert_change)
                        resume_token = stream.resume_token
            except pymongo.errors.PyMongoError:
                # The ChangeStream encountered an unrecoverable error or the
                # resume attempt failed to recreate the cursor.
                if resume_token is None:
                    # There is no usable resume token because there was a
                    # failure during ChangeStream initialization.
                    logging.error('...')
                else:
                    # Use the interrupted ChangeStream's resume token to create
                    # a new ChangeStream. The new stream will continue from the
                    # last seen insert change without missing any events.
                    async with db.collection.watch(
                            pipeline, resume_after=resume_token) as stream:
                        async for insert_change in stream:
                            print(insert_change)

        Raises :exc:`StopAsyncIteration` if this ChangeStream is closed.
        """
        # Keep polling while the stream is alive; try_next returns None when
        # a getMore produced no change.
        while self.alive:
            doc = await self.try_next()
            if doc is not None:
                return doc

        raise StopAsyncIteration

    __anext__ = next

    @property
    def alive(self) -> bool:
        """Does this cursor have the potential to return more data?

        .. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
            :exc:`StopAsyncIteration` and :meth:`try_next` can return ``None``.

        .. versionadded:: 3.8
        """
        return not self._closed

    @_csot.apply
    async def try_next(self) -> Optional[_DocumentType]:
        """Advance the cursor without blocking indefinitely.

        This method returns the next change document without waiting
        indefinitely for the next change. For example::

            async with db.collection.watch() as stream:
                while stream.alive:
                    change = await stream.try_next()
                    # Note that the ChangeStream's resume token may be updated
                    # even when no changes are returned.
                    print("Current resume token: %r" % (stream.resume_token,))
                    if change is not None:
                        print("Change document: %r" % (change,))
                        continue
                    # We end up here when there are no recent changes.
                    # Sleep for a while before trying again to avoid flooding
                    # the server with getMore requests when no changes are
                    # available.
                    await asyncio.sleep(10)

        If no change document is cached locally then this method runs a single
        getMore command. If the getMore yields any documents, the next
        document is returned, otherwise, if the getMore returns no documents
        (because there have been no changes) then ``None`` is returned.

        :return: The next change document or ``None`` when no document is available
            after running a single getMore or when the cursor is closed.

        .. versionadded:: 3.8
        """
        # The stream is still open but the server-side cursor is gone:
        # recreate it before attempting to read.
        if not self._closed and not self._cursor.alive:
            await self._resume()

        # Attempt to get the next change with at most one getMore and at most
        # one resume attempt.
        try:
            try:
                change = await self._cursor._try_next(True)
            except PyMongoError as exc:
                if not _resumable(exc):
                    raise
                await self._resume()
                # NOTE(review): the flag differs from the first attempt,
                # presumably to avoid an extra getMore right after the fresh
                # aggregate -- confirm against AsyncCommandCursor._try_next.
                change = await self._cursor._try_next(False)
        except PyMongoError as exc:
            # Close the stream after a fatal error.
            if not _resumable(exc) and not exc.timeout:
                await self.close()
            raise
        except Exception:
            await self.close()
            raise

        # Check if the cursor was invalidated.
        if not self._cursor.alive:
            self._closed = True

        # If no changes are available.
        if change is None:
            # We have either iterated over all documents in the cursor,
            # OR the most-recently returned batch is empty. In either case,
            # update the cached resume token with the postBatchResumeToken if
            # one was returned. We also clear the startAtOperationTime.
            if self._cursor._post_batch_resume_token is not None:
                self._resume_token = self._cursor._post_batch_resume_token
                self._start_at_operation_time = None
            return change

        # Else, changes are available.
        try:
            resume_token = change["_id"]
        except KeyError:
            await self.close()
            raise InvalidOperation(
                "Cannot provide resume functionality when the resume token is missing."
            ) from None

        # If this is the last change document from the current batch, cache the
        # postBatchResumeToken.
        if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
            resume_token = self._cursor._post_batch_resume_token

        # Hereafter, don't use startAfter; instead use resumeAfter.
        self._uses_start_after = False
        self._uses_resume_after = True

        # Cache the resume token and clear startAtOperationTime.
        self._resume_token = resume_token
        self._start_at_operation_time = None

        # The change was fetched as raw BSON; decode it with the caller's
        # original codec options so custom type decoders are applied.
        if self._decode_custom:
            return _bson_to_dict(change.raw, self._orig_codec_options)
        return change

    async def __aenter__(self) -> ChangeStream[_DocumentType]:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Always close the stream on exit, whether or not an exception occurred.
        await self.close()
|
||||
|
||||
|
||||
class CollectionChangeStream(ChangeStream[_DocumentType]):
    """Change stream scoped to a single collection.

    Internal class: applications should obtain instances through
    :meth:`pymongo.asynchronous.collection.AsyncCollection.watch` rather than
    constructing one directly.

    .. versionadded:: 3.7
    """

    _target: AsyncCollection[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
        """Aggregation command type used to open the stream on a collection."""
        command_class = _CollectionAggregationCommand
        return command_class

    @property
    def _client(self) -> AsyncMongoClient[_DocumentType]:
        """Client reached through the watched collection's database."""
        database = self._target.database
        return database.client
|
||||
|
||||
|
||||
class DatabaseChangeStream(ChangeStream[_DocumentType]):
    """Change stream covering every collection of one database.

    Internal class: applications should obtain instances through
    :meth:`pymongo.asynchronous.database.AsyncDatabase.watch` rather than
    constructing one directly.

    .. versionadded:: 3.7
    """

    _target: AsyncDatabase[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
        """Aggregation command type used to open the stream on a database."""
        command_class = _DatabaseAggregationCommand
        return command_class

    @property
    def _client(self) -> AsyncMongoClient[_DocumentType]:
        """Client that owns the watched database."""
        target = self._target
        return target.client
|
||||
|
||||
|
||||
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
    """Change stream covering every collection in the cluster.

    Internal class: applications should obtain instances through
    :meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.watch` rather
    than constructing one directly.

    .. versionadded:: 3.7
    """

    def _change_stream_options(self) -> dict[str, Any]:
        """Return the inherited $changeStream options with
        ``allChangesForCluster`` switched on.
        """
        stage_options = super()._change_stream_options()
        stage_options.update(allChangesForCluster=True)
        return stage_options
|
||||
334
pymongo/asynchronous/client_options.py
Normal file
334
pymongo/asynchronous/client_options.py
Normal file
@ -0,0 +1,334 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Tools to parse mongo client options."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.compression_support import CompressionSettings
|
||||
from pymongo.asynchronous.monitoring import _EventListener, _EventListeners
|
||||
from pymongo.asynchronous.pool import PoolOptions
|
||||
from pymongo.asynchronous.read_preferences import (
|
||||
_ServerMode,
|
||||
make_read_preference,
|
||||
read_pref_mode_from_name,
|
||||
)
|
||||
from pymongo.asynchronous.server_selectors import any_server_selector
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.asynchronous.auth import MongoCredential
|
||||
from pymongo.asynchronous.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.asynchronous.topology_description import _ServerSelector
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _parse_credentials(
|
||||
username: str, password: str, database: Optional[str], options: Mapping[str, Any]
|
||||
) -> Optional[MongoCredential]:
|
||||
"""Parse authentication credentials."""
|
||||
mechanism = options.get("authmechanism", "DEFAULT" if username else None)
|
||||
source = options.get("authsource")
|
||||
if username or mechanism:
|
||||
from pymongo.asynchronous.auth import _build_credentials_tuple
|
||||
|
||||
return _build_credentials_tuple(mechanism, source, username, password, options, database)
|
||||
return None
|
||||
|
||||
|
||||
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
|
||||
"""Parse read preference options."""
|
||||
if "read_preference" in options:
|
||||
return options["read_preference"]
|
||||
|
||||
name = options.get("readpreference", "primary")
|
||||
mode = read_pref_mode_from_name(name)
|
||||
tags = options.get("readpreferencetags")
|
||||
max_staleness = options.get("maxstalenessseconds", -1)
|
||||
return make_read_preference(mode, tags, max_staleness)
|
||||
|
||||
|
||||
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
    """Build a WriteConcern from the ``w``, ``wtimeoutms``, ``journal`` and
    ``fsync`` options; missing options are passed as ``None``.
    """
    return WriteConcern(
        options.get("w"),
        options.get("wtimeoutms"),
        options.get("journal"),
        options.get("fsync"),
    )
|
||||
|
||||
|
||||
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
    """Build a ReadConcern from the ``readconcernlevel`` option (``None`` if absent)."""
    return ReadConcern(options.get("readconcernlevel"))
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
|
||||
"""Parse ssl options."""
|
||||
use_tls = options.get("tls")
|
||||
if use_tls is not None:
|
||||
validate_boolean("tls", use_tls)
|
||||
|
||||
certfile = options.get("tlscertificatekeyfile")
|
||||
passphrase = options.get("tlscertificatekeyfilepassword")
|
||||
ca_certs = options.get("tlscafile")
|
||||
crlfile = options.get("tlscrlfile")
|
||||
allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
|
||||
allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
|
||||
disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)
|
||||
|
||||
enabled_tls_opts = []
|
||||
for opt in (
|
||||
"tlscertificatekeyfile",
|
||||
"tlscertificatekeyfilepassword",
|
||||
"tlscafile",
|
||||
"tlscrlfile",
|
||||
):
|
||||
# Any non-null value of these options implies tls=True.
|
||||
if opt in options and options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
for opt in (
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
):
|
||||
# A value of False for these options implies tls=True.
|
||||
if opt in options and not options[opt]:
|
||||
enabled_tls_opts.append(opt)
|
||||
|
||||
if enabled_tls_opts:
|
||||
if use_tls is None:
|
||||
# Implicitly enable TLS when one of the tls* options is set.
|
||||
use_tls = True
|
||||
elif not use_tls:
|
||||
# Error since tls is explicitly disabled but a tls option is set.
|
||||
raise ConfigurationError(
|
||||
"TLS has not been enabled but the "
|
||||
"following tls parameters have been set: "
|
||||
"%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
|
||||
)
|
||||
|
||||
if use_tls:
|
||||
ctx = get_ssl_context(
|
||||
certfile,
|
||||
passphrase,
|
||||
ca_certs,
|
||||
crlfile,
|
||||
allow_invalid_certificates,
|
||||
allow_invalid_hostnames,
|
||||
disable_ocsp_endpoint_check,
|
||||
)
|
||||
return ctx, allow_invalid_hostnames
|
||||
return None, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(
    username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
    """Assemble the connection pool configuration from client options.

    Credentials are parsed first, then the pool sizes are validated, then the
    compression and TLS settings are derived; unspecified options fall back
    to the defaults in ``common``.

    Raises :exc:`ValueError` when ``minpoolsize`` exceeds a non-``None``
    ``maxpoolsize``.
    """
    credentials = _parse_credentials(username, password, database, options)

    max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
    min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
    if max_pool_size is not None and min_pool_size > max_pool_size:
        raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")

    listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
    compression_settings = CompressionSettings(
        options.get("compressors", []), options.get("zlibcompressionlevel", -1)
    )
    ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)

    return PoolOptions(
        max_pool_size,
        min_pool_size,
        options.get("maxidletimems", common.MAX_IDLE_TIME_SEC),
        options.get("connecttimeoutms", common.CONNECT_TIMEOUT),
        options.get("sockettimeoutms"),
        options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT),
        ssl_context,
        tls_allow_invalid_hostnames,
        _EventListeners(listeners),
        options.get("appname"),
        options.get("driver"),
        compression_settings,
        max_connecting=options.get("maxconnecting", common.MAX_CONNECTING),
        server_api=options.get("server_api"),
        load_balanced=options.get("loadbalanced"),
        credentials=credentials,
    )
|
||||
|
||||
|
||||
class ClientOptions:
    """Read only configuration options for an AsyncMongoClient.

    Should not be instantiated directly by application developers. Access
    a client's options via
    :attr:`pymongo.asynchronous.mongo_client.AsyncMongoClient.options`
    instead.

    All values are parsed once in ``__init__`` and exposed through read-only
    properties.
    """

    def __init__(
        self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
    ) -> None:
        # `options` is looked up with lowercase keys (e.g. "retrywrites");
        # presumably the caller has already validated and normalized it --
        # TODO confirm against the client constructor.
        self.__options = options
        self.__codec_options = _parse_codec_options(options)
        self.__direct_connection = options.get("directconnection")
        self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
        # self.__server_selection_timeout is in seconds. Must use full name for
        # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
        self.__server_selection_timeout = options.get(
            "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
        )
        self.__pool_options = _parse_pool_options(username, password, database, options)
        self.__read_preference = _parse_read_preference(options)
        self.__replica_set_name = options.get("replicaset")
        self.__write_concern = _parse_write_concern(options)
        self.__read_concern = _parse_read_concern(options)
        self.__connect = options.get("connect")
        self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
        self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
        self.__retry_reads = options.get("retryreads", common.RETRY_READS)
        self.__server_selector = options.get("server_selector", any_server_selector)
        self.__auto_encryption_opts = options.get("auto_encryption_opts")
        self.__load_balanced = options.get("loadbalanced")
        self.__timeout = options.get("timeoutms")
        self.__server_monitoring_mode = options.get(
            "servermonitoringmode", common.SERVER_MONITORING_MODE
        )

    @property
    def _options(self) -> Mapping[str, Any]:
        """The original options used to create this ClientOptions."""
        return self.__options

    @property
    def connect(self) -> Optional[bool]:
        """Whether to begin discovering a MongoDB topology automatically."""
        return self.__connect

    @property
    def codec_options(self) -> CodecOptions:
        """A :class:`~bson.codec_options.CodecOptions` instance."""
        return self.__codec_options

    @property
    def direct_connection(self) -> Optional[bool]:
        """Whether to connect to the deployment in 'Single' topology."""
        return self.__direct_connection

    @property
    def local_threshold_ms(self) -> int:
        """The local threshold for this instance."""
        return self.__local_threshold_ms

    @property
    def server_selection_timeout(self) -> int:
        """The server selection timeout for this instance in seconds."""
        return self.__server_selection_timeout

    @property
    def server_selector(self) -> _ServerSelector:
        """The ``server_selector`` option, or ``any_server_selector`` when
        none was supplied.
        """
        return self.__server_selector

    @property
    def heartbeat_frequency(self) -> int:
        """The monitoring frequency in seconds."""
        return self.__heartbeat_frequency

    @property
    def pool_options(self) -> PoolOptions:
        """A :class:`~pymongo.asynchronous.pool.PoolOptions` instance."""
        return self.__pool_options

    @property
    def read_preference(self) -> _ServerMode:
        """A read preference instance."""
        return self.__read_preference

    @property
    def replica_set_name(self) -> Optional[str]:
        """Replica set name or None."""
        return self.__replica_set_name

    @property
    def write_concern(self) -> WriteConcern:
        """A :class:`~pymongo.write_concern.WriteConcern` instance."""
        return self.__write_concern

    @property
    def read_concern(self) -> ReadConcern:
        """A :class:`~pymongo.read_concern.ReadConcern` instance."""
        return self.__read_concern

    @property
    def timeout(self) -> Optional[float]:
        """The configured timeoutMS converted to seconds, or None.

        .. versionadded:: 4.2
        """
        return self.__timeout

    @property
    def retry_writes(self) -> bool:
        """If this instance should retry supported write operations."""
        return self.__retry_writes

    @property
    def retry_reads(self) -> bool:
        """If this instance should retry supported read operations."""
        return self.__retry_reads

    @property
    def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
        """A :class:`~pymongo.asynchronous.encryption_options.AutoEncryptionOpts` or None."""
        return self.__auto_encryption_opts

    @property
    def load_balanced(self) -> Optional[bool]:
        """True if the client was configured to connect to a load balancer."""
        return self.__load_balanced

    @property
    def event_listeners(self) -> list[_EventListeners]:
        """The event listeners registered for this client.

        See :mod:`~pymongo.monitoring` for details.

        .. versionadded:: 4.0
        """
        # The assert narrows the Optional type for type checkers.
        assert self.__pool_options._event_listeners is not None
        return self.__pool_options._event_listeners.event_listeners()

    @property
    def server_monitoring_mode(self) -> str:
        """The configured serverMonitoringMode option.

        .. versionadded:: 4.5
        """
        return self.__server_monitoring_mode
|
||||
1161
pymongo/asynchronous/client_session.py
Normal file
1161
pymongo/asynchronous/client_session.py
Normal file
File diff suppressed because it is too large
Load Diff
226
pymongo/asynchronous/collation.py
Normal file
226
pymongo/asynchronous/collation.py
Normal file
@ -0,0 +1,226 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for working with `collations`_.
|
||||
|
||||
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class CollationStrength:
    """Namespace of the values accepted for `strength` on a
    :class:`~pymongo.collation.Collation`.

    This is a plain class holding integer constants, not an
    :class:`enum.Enum`.
    """

    #: Differentiate base (unadorned) characters.
    PRIMARY = 1

    #: Differentiate character accents.
    SECONDARY = 2

    #: Differentiate character case.
    TERTIARY = 3

    #: Differentiate words with and without punctuation.
    QUATERNARY = 4

    #: Differentiate unicode code point (characters are exactly identical).
    IDENTICAL = 5
|
||||
|
||||
|
||||
class CollationAlternate:
    """Namespace of the values accepted for `alternate` on a
    :class:`~pymongo.collation.Collation`.

    This is a plain class holding string constants, not an
    :class:`enum.Enum`.
    """

    #: Spaces and punctuation are treated as base characters.
    NON_IGNORABLE = "non-ignorable"

    #: Spaces and punctuation are *not* considered base characters.
    #:
    #: Spaces and punctuation are distinguished regardless when the
    #: :class:`~pymongo.collation.Collation` strength is at least
    #: :data:`~pymongo.collation.CollationStrength.QUATERNARY`.
    SHIFTED = "shifted"
|
||||
|
||||
|
||||
class CollationMaxVariable:
    """Namespace of the values accepted for `max_variable` on a
    :class:`~pymongo.collation.Collation`.

    This is a plain class holding string constants, not an
    :class:`enum.Enum`.
    """

    #: Both punctuation and spaces are ignored.
    PUNCT = "punct"

    #: Spaces alone are ignored.
    SPACE = "space"
|
||||
|
||||
|
||||
class CollationCaseFirst:
    """Namespace of the values accepted for `case_first` on a
    :class:`~pymongo.collation.Collation`.

    This is a plain class holding string constants, not an
    :class:`enum.Enum`.
    """

    #: Sort uppercase characters first.
    UPPER = "upper"

    #: Sort lowercase characters first.
    LOWER = "lower"

    #: Default for locale or collation strength.
    OFF = "off"
|
||||
|
||||
|
||||
class Collation:
    """Collation

    :param locale: (string) The locale of the collation. This should be a string
        that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
        valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
        documentation for a list of supported locales.
    :param caseLevel: (optional) If ``True``, turn on case sensitivity if
        `strength` is 1 or 2 (case sensitivity is implied if `strength` is
        greater than 2). Defaults to ``False``.
    :param caseFirst: (optional) Specify that either uppercase or lowercase
        characters take precedence. Must be one of the following values:

          * :data:`~CollationCaseFirst.UPPER`
          * :data:`~CollationCaseFirst.LOWER`
          * :data:`~CollationCaseFirst.OFF` (the default)

    :param strength: Specify the comparison strength. This is also
        known as the ICU comparison level. This must be one of the following
        values:

          * :data:`~CollationStrength.PRIMARY`
          * :data:`~CollationStrength.SECONDARY`
          * :data:`~CollationStrength.TERTIARY` (the default)
          * :data:`~CollationStrength.QUATERNARY`
          * :data:`~CollationStrength.IDENTICAL`

        Each successive level builds upon the previous. For example, a
        `strength` of :data:`~CollationStrength.SECONDARY` differentiates
        characters based both on the unadorned base character and its accents.

    :param numericOrdering: If ``True``, order numbers numerically
        instead of in collation order (defaults to ``False``).
    :param alternate: Specify whether spaces and punctuation are
        considered base characters. This must be one of the following values:

          * :data:`~CollationAlternate.NON_IGNORABLE` (the default)
          * :data:`~CollationAlternate.SHIFTED`

    :param maxVariable: When `alternate` is
        :data:`~CollationAlternate.SHIFTED`, this option specifies what
        characters may be ignored. This must be one of the following values:

          * :data:`~CollationMaxVariable.PUNCT` (the default)
          * :data:`~CollationMaxVariable.SPACE`

    :param normalization: If ``True``, normalizes text into Unicode
        NFD. Defaults to ``False``.
    :param backwards: If ``True``, accents on characters are
        considered from the back of the word to the front, as it is done in some
        French dictionary ordering traditions. Defaults to ``False``.
    :param kwargs: Keyword arguments supplying any additional options
        to be sent with this Collation object.

    .. versionadded:: 3.4

    """

    __slots__ = ("__document",)

    def __init__(
        self,
        locale: str,
        caseLevel: Optional[bool] = None,
        caseFirst: Optional[str] = None,
        strength: Optional[int] = None,
        numericOrdering: Optional[bool] = None,
        alternate: Optional[str] = None,
        maxVariable: Optional[str] = None,
        normalization: Optional[bool] = None,
        backwards: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Validate the options and build the collation document.

        Options left as ``None`` are omitted from the document. Extra
        keyword arguments are merged in last, without validation.
        """
        self.__document: dict[str, Any] = {"locale": common.validate_string("locale", locale)}

        # (document key, supplied value, validator), in the order the keys
        # are inserted into the document.
        optional_fields = (
            ("caseLevel", caseLevel, validate_boolean),
            ("caseFirst", caseFirst, common.validate_string),
            ("strength", strength, common.validate_integer),
            ("numericOrdering", numericOrdering, validate_boolean),
            ("alternate", alternate, common.validate_string),
            ("maxVariable", maxVariable, common.validate_string),
            ("normalization", normalization, validate_boolean),
            ("backwards", backwards, validate_boolean),
        )
        for field, supplied, validator in optional_fields:
            if supplied is not None:
                self.__document[field] = validator(field, supplied)

        self.__document.update(kwargs)

    @property
    def document(self) -> dict[str, Any]:
        """The document representation of this collation.

        .. note::
          :class:`Collation` is immutable. Mutating the value of
          :attr:`document` does not mutate this :class:`Collation`.
        """
        # Hand out a shallow copy so callers cannot alter the stored options.
        return self.__document.copy()

    def __repr__(self) -> str:
        """Render as ``Collation(key=value, ...)`` in document order."""
        rendered = ", ".join(f"{key}={value!r}" for key, value in self.document.items())
        return f"Collation({rendered})"

    def __eq__(self, other: Any) -> bool:
        """Compare by document; defer to `other` when it is not a Collation."""
        if not isinstance(other, Collation):
            return NotImplemented
        return self.document == other.document

    def __ne__(self, other: Any) -> bool:
        """Return the negation of ``self == other``."""
        return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(
    value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
    """Normalize a collation argument to a plain document or ``None``.

    :param value: ``None``, a :class:`Collation`, or a ``dict``.
    :return: ``None`` for ``None``, the collation's ``document`` for a
        :class:`Collation`, or the ``dict`` itself (not copied).
    :raises TypeError: If `value` is none of the accepted types.
    """
    if value is None:
        return None
    if isinstance(value, Collation):
        return value.document
    if not isinstance(value, dict):
        raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
    return value
|
||||
3556
pymongo/asynchronous/collection.py
Normal file
3556
pymongo/asynchronous/collection.py
Normal file
File diff suppressed because it is too large
Load Diff
415
pymongo/asynchronous/command_cursor.py
Normal file
415
pymongo/asynchronous/command_cursor.py
Normal file
@ -0,0 +1,415 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CommandCursor class to iterate over command results."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Generic,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.asynchronous.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
_OpReply,
|
||||
_RawBatchGetMore,
|
||||
)
|
||||
from pymongo.asynchronous.response import PinnedResponse
|
||||
from pymongo.asynchronous.typings import _Address, _DocumentOut, _DocumentType
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AsyncCommandCursor(Generic[_DocumentType]):
    """An asynchronous cursor / iterator over command cursors."""

    _getmore_class = _GetMore

    def __init__(
        self,
        collection: AsyncCollection[_DocumentType],
        cursor_info: Mapping[str, Any],
        address: Optional[_Address],
        batch_size: int = 0,
        max_await_time_ms: Optional[int] = None,
        session: Optional[ClientSession] = None,
        explicit_session: bool = False,
        comment: Any = None,
    ) -> None:
        """Create a new command cursor.

        :param collection: The collection the cursor belongs to.
        :param cursor_info: The cursor document; must contain ``"id"`` and
            ``"firstBatch"``, and may contain ``"postBatchResumeToken"``
            and ``"ns"``.
        :param address: Address of the server used, or None.
        :param batch_size: Batch size for later getMore operations.
        :param max_await_time_ms: An integer or None.
        :param session: The session associated with the cursor, if any.
        :param explicit_session: Whether `session` is exposed through
            :attr:`session` and kept after the cursor is closed.
        :param comment: Passed along when building getMore operations.
        :raises TypeError: If `batch_size` is not an integer, or
            `max_await_time_ms` is neither an integer nor None.
        :raises ValueError: If `batch_size` is negative.
        """
        self._sock_mgr: Any = None
        self._collection: AsyncCollection[_DocumentType] = collection
        self._id = cursor_info["id"]
        self._data = deque(cursor_info["firstBatch"])
        self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
            "postBatchResumeToken"
        )
        self._address = address
        self._batch_size = batch_size
        self._max_await_time_ms = max_await_time_ms
        self._session = session
        self._explicit_session = explicit_session
        # A cursor id of 0 means there is nothing left to kill.
        self._killed = self._id == 0
        self._comment = comment
        if _IS_SYNC and self._killed:
            self._end_session(True)  # type: ignore[unused-coroutine]

        # Only read collection.full_name when the reply carries no "ns".
        self._ns = cursor_info["ns"] if "ns" in cursor_info else collection.full_name  # noqa: SIM401

        self.batch_size(batch_size)

        if max_await_time_ms is not None and not isinstance(max_await_time_ms, int):
            raise TypeError("max_await_time_ms must be an integer or None")

    def __del__(self) -> None:
        # _die is a coroutine in this class, so the finalizer only calls it
        # when _IS_SYNC is set.
        if _IS_SYNC:
            self._die(False)  # type: ignore[unused-coroutine]

    def batch_size(self, batch_size: int) -> AsyncCommandCursor[_DocumentType]:
        """Limits the number of documents returned in one batch. Each batch
        requires a round trip to the server. It can be adjusted to optimize
        performance and limit data transfer.

        .. note:: batch_size can not override MongoDB's internal limits on the
           amount of data it will return to the client in a single batch (i.e
           if you set batch size to 1,000,000,000, MongoDB will currently only
           return 4-16MB of results per batch).

        Raises :exc:`TypeError` if `batch_size` is not an integer.
        Raises :exc:`ValueError` if `batch_size` is less than ``0``.

        :param batch_size: The size of each batch of results requested.
        :return: This cursor, so calls can be chained.
        """
        if not isinstance(batch_size, int):
            raise TypeError("batch_size must be an integer")
        if batch_size < 0:
            raise ValueError("batch_size must be >= 0")

        # A requested batch size of 1 is stored as 2.
        self._batch_size = 2 if batch_size == 1 else batch_size
        return self

    def _has_next(self) -> bool:
        """Returns `True` if the cursor has documents remaining from the
        previous batch.
        """
        return len(self._data) > 0

    @property
    def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
        """Retrieve the postBatchResumeToken from the response to a
        changeStream aggregate or getMore.
        """
        return self._postbatchresumetoken

    async def _maybe_pin_connection(self, conn: Connection) -> None:
        """Pin `conn` to this cursor when the client asks for pinning."""
        client = self._collection.database.client
        if not client._should_pin_cursor(self._session):
            return
        if self._sock_mgr:
            # A connection manager is already in place.
            return
        conn.pin_cursor()
        manager = _ConnectionManager(conn, False)
        # Ensure the connection gets returned when the entire result is
        # returned in the first batch.
        if self._id == 0:
            await manager.close()
        else:
            self._sock_mgr = manager

    def _unpack_response(
        self,
        response: Union[_OpReply, _OpMsg],
        cursor_id: Optional[int],
        codec_options: CodecOptions[Mapping[str, Any]],
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> Sequence[_DocumentOut]:
        """Delegate decoding of `response` to its ``unpack_response``."""
        return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)

    @property
    def alive(self) -> bool:
        """Does this cursor have the potential to return more data?

        Even if :attr:`alive` is ``True``, :meth:`next` can raise
        :exc:`StopAsyncIteration`. Best to use an ``async for`` loop::

            async for doc in collection.aggregate(pipeline):
                print(doc)

        .. note:: :attr:`alive` can be True while iterating a cursor from
          a failed server. In this case :attr:`alive` will return False after
          :meth:`next` fails to retrieve the next batch of results from the
          server.
        """
        return len(self._data) > 0 or not self._killed

    @property
    def cursor_id(self) -> int:
        """Returns the id of the cursor."""
        return self._id

    @property
    def address(self) -> Optional[_Address]:
        """The (host, port) of the server used, or None.

        .. versionadded:: 3.0
        """
        return self._address

    @property
    def session(self) -> Optional[ClientSession]:
        """The cursor's :class:`~pymongo.client_session.ClientSession`, or None.

        Only a session flagged as explicit is returned.

        .. versionadded:: 3.6
        """
        return self._session if self._explicit_session else None

    async def _die(self, synchronous: bool = False) -> None:
        """Closes this cursor."""
        was_killed = self._killed
        self._killed = True
        if self._id and not was_killed:
            kill_id = self._id
            assert self._address is not None
            kill_address = _CursorAddress(self._address, self._ns)
        else:
            # Skip killCursors.
            kill_id = 0
            kill_address = None
        await self._collection.database.client._cleanup_cursor(
            synchronous,
            kill_id,
            kill_address,
            self._sock_mgr,
            self._session,
            self._explicit_session,
        )
        # An explicit session stays attached to the cursor.
        if not self._explicit_session:
            self._session = None
        self._sock_mgr = None

    async def _end_session(self, synchronous: bool) -> None:
        """End the session unless it is missing or flagged as explicit."""
        if not self._session or self._explicit_session:
            return
        await self._session._end_session(lock=synchronous)
        self._session = None

    async def close(self) -> None:
        """Explicitly close / kill this cursor."""
        await self._die(True)

    async def _send_message(self, operation: _GetMore) -> None:
        """Send a getmore message and handle the response.

        On any failure the cursor is closed and the exception re-raised.
        """
        client = self._collection.database.client
        try:
            response = await client._run_operation(
                operation, self._unpack_response, address=self._address
            )
        except OperationFailure as exc:
            if exc.code in _CURSOR_CLOSED_ERRORS:
                # Don't send killCursors because the cursor is already closed.
                self._killed = True
            if exc.timeout:
                await self._die(False)
            else:
                # Return the session and pinned connection, if necessary.
                await self.close()
            raise
        except ConnectionFailure:
            # Don't send killCursors because the cursor is already closed.
            self._killed = True
            # Return the session and pinned connection, if necessary.
            await self.close()
            raise
        except Exception:
            await self.close()
            raise

        if isinstance(response, PinnedResponse) and not self._sock_mgr:
            self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)

        if response.from_command:
            cursor = response.docs[0]["cursor"]
            documents = cursor["nextBatch"]
            self._postbatchresumetoken = cursor.get("postBatchResumeToken")
            self._id = cursor["id"]
        else:
            documents = response.docs
            assert isinstance(response.data, _OpReply)
            self._id = response.data.cursor_id

        # Close before storing the batch; buffered documents stay readable.
        if self._id == 0:
            await self.close()
        self._data = deque(documents)

    async def _refresh(self) -> int:
        """Refreshes the cursor with more data from the server.

        Returns the length of self._data after refresh. Will exit early if
        self._data is already non-empty. Raises OperationFailure when the
        cursor cannot be refreshed due to an error on the query.
        """
        if len(self._data) or self._killed:
            return len(self._data)

        if self._id:  # Get More
            dbname, collname = self._ns.split(".", 1)
            read_pref = self._collection._read_preference_for(self.session)
            getmore = self._getmore_class(
                dbname,
                collname,
                self._batch_size,
                self._id,
                self._collection.codec_options,
                read_pref,
                self._session,
                self._collection.database.client,
                self._max_await_time_ms,
                self._sock_mgr,
                False,
                self._comment,
            )
            await self._send_message(getmore)
        else:  # Cursor id is zero nothing else to return
            await self._die(True)

        return len(self._data)

    def __aiter__(self) -> AsyncIterator[_DocumentType]:
        return self

    async def next(self) -> _DocumentType:
        """Advance the cursor.

        :raises StopAsyncIteration: When the cursor is no longer alive.
        """
        # Block until a document is returnable.
        while self.alive:
            doc = await self._try_next(True)
            if doc is not None:
                return doc

        raise StopAsyncIteration

    async def __anext__(self) -> _DocumentType:
        return await self.next()

    async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
        """Advance the cursor blocking for at most one getMore command."""
        buffer_empty = not len(self._data)
        if buffer_empty and not self._killed and get_more_allowed:
            await self._refresh()
        if len(self._data):
            return self._data.popleft()
        return None

    async def try_next(self) -> Optional[_DocumentType]:
        """Advance the cursor without blocking indefinitely.

        This method returns the next document without waiting
        indefinitely for data.

        If no document is cached locally then this method runs a single
        getMore command. If the getMore yields any documents, the next
        document is returned, otherwise, if the getMore returns no documents
        (because there is no additional data) then ``None`` is returned.

        :return: The next document or ``None`` when no document is available
          after running a single getMore or when the cursor is closed.

        .. versionadded:: 4.5
        """
        return await self._try_next(get_more_allowed=True)

    async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.close()

    async def to_list(self) -> list[_DocumentType]:
        """Iterate the cursor to exhaustion and return the documents as a list."""
        return [x async for x in self]  # noqa: C416,RUF100
|
||||
|
||||
|
||||
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
    """Command cursor / iterator over raw batches of BSON data."""

    _getmore_class = _RawBatchGetMore

    def __init__(
        self,
        collection: AsyncCollection[_DocumentType],
        cursor_info: Mapping[str, Any],
        address: Optional[_Address],
        batch_size: int = 0,
        max_await_time_ms: Optional[int] = None,
        session: Optional[ClientSession] = None,
        explicit_session: bool = False,
        comment: Any = None,
    ) -> None:
        """Create a new cursor / iterator over raw batches of BSON data.

        Should not be called directly by application developers -
        see :meth:`~pymongo.collection.AsyncCollection.aggregate_raw_batches`
        instead.

        .. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
        """
        # The cursor must be created with an empty (or absent) first batch.
        assert not cursor_info.get("firstBatch")
        super().__init__(
            collection,
            cursor_info,
            address,
            batch_size=batch_size,
            max_await_time_ms=max_await_time_ms,
            session=session,
            explicit_session=explicit_session,
            comment=comment,
        )

    def _unpack_response(  # type: ignore[override]
        self,
        response: Union[_OpReply, _OpMsg],
        cursor_id: Optional[int],
        codec_options: CodecOptions,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[Mapping[str, Any]]:
        """Return the reply's raw response instead of decoded documents."""
        raw = response.raw_response(cursor_id, user_fields=user_fields)
        if legacy_response:
            return raw  # type: ignore[return-value]
        # OP_MSG returns firstBatch/nextBatch documents as a BSON array
        # Re-assemble the array of documents into a document stream
        _convert_raw_document_lists_to_streams(raw[0])
        return raw  # type: ignore[return-value]

    def __getitem__(self, index: int) -> NoReturn:
        """Indexing is not supported; always raises InvalidOperation."""
        raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
|
||||
1062
pymongo/asynchronous/common.py
Normal file
1062
pymongo/asynchronous/common.py
Normal file
File diff suppressed because it is too large
Load Diff
178
pymongo/asynchronous/compression_support.py
Normal file
178
pymongo/asynchronous/compression_support.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright 2018 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
# Compressor names accepted by validate_compressors.
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
# The hello command names plus every sensitive command name.
# NOTE(review): presumably the commands that are never compressed - confirm
# against the code that consumes this set.
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD, *_SENSITIVE_COMMANDS}
|
||||
|
||||
|
||||
def _have_snappy() -> bool:
|
||||
try:
|
||||
import snappy # type:ignore[import] # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _have_zlib() -> bool:
|
||||
try:
|
||||
import zlib # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _have_zstd() -> bool:
|
||||
try:
|
||||
import zstandard # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]:
    """Filter a compressor list down to the usable entries.

    :param dummy: Unused.
    :param value: A comma-separated string or an iterable of compressor
        names.
    :return: A new list, in the original order, of the names that are
        supported and whose backing module can be imported. A warning is
        issued for every entry that is dropped.
    """
    try:
        # `value` is string.
        candidates = value.split(",")  # type: ignore[union-attr]
    except AttributeError:
        # `value` is an iterable.
        candidates = list(value)

    usable: list[str] = []
    for name in candidates:
        if name not in _SUPPORTED_COMPRESSORS:
            problem = f"Unsupported compressor: {name}"
        elif name == "snappy" and not _have_snappy():
            problem = (
                "Wire protocol compression with snappy is not available. "
                "You must install the python-snappy module for snappy support."
            )
        elif name == "zlib" and not _have_zlib():
            problem = (
                "Wire protocol compression with zlib is not available. "
                "The zlib module is not available."
            )
        elif name == "zstd" and not _have_zstd():
            problem = (
                "Wire protocol compression with zstandard is not available. "
                "You must install the zstandard module for zstandard support."
            )
        else:
            usable.append(name)
            continue
        # Warn from this frame so stacklevel=2 still points at the caller.
        warnings.warn(problem, stacklevel=2)
    return usable
|
||||
|
||||
|
||||
def validate_zlib_compression_level(option: str, value: Any) -> int:
    """Coerce `value` to an integer zlib compression level.

    :param option: Option name, used only in error messages.
    :param value: Anything ``int()`` accepts.
    :return: The level as an ``int`` in the inclusive range -1..9.
    :raises TypeError: If `value` cannot be converted by ``int()``.
    :raises ValueError: If the converted level is outside -1..9.
    """
    try:
        level = int(value)
    except Exception:
        raise TypeError(f"{option} must be an integer, not {value!r}.") from None
    if not -1 <= level <= 9:
        raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
    return level
|
||||
|
||||
|
||||
class CompressionSettings:
    """Holds the configured compressors and the zlib compression level."""

    def __init__(self, compressors: list[str], zlib_compression_level: int):
        """Store the compressor list and zlib level as given."""
        self.compressors = compressors
        self.zlib_compression_level = zlib_compression_level

    def get_compression_context(
        self, compressors: Optional[list[str]]
    ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
        """Build a compression context for the first entry of `compressors`.

        :param compressors: Candidate compressor names; only the first is
            considered.
        :return: A context for ``"snappy"``, ``"zlib"`` or ``"zstd"``, or
            ``None`` when the list is empty/None or the first name is not
            one of those.
        """
        if not compressors:
            return None
        chosen = compressors[0]
        if chosen == "snappy":
            return SnappyContext()
        if chosen == "zlib":
            return ZlibContext(self.zlib_compression_level)
        if chosen == "zstd":
            return ZstdContext()
        return None
|
||||
|
||||
|
||||
class SnappyContext:
    """Compression context backed by the ``snappy`` module."""

    # Identifier that decompress() matches against for snappy data.
    compressor_id = 1

    @staticmethod
    def compress(data: bytes) -> bytes:
        """Return `data` compressed with ``snappy.compress``."""
        # Imported lazily so the module is only needed when actually used.
        import snappy

        compressed = snappy.compress(data)
        return compressed
|
||||
|
||||
|
||||
class ZlibContext:
    """Compression context backed by the standard ``zlib`` module."""

    # Identifier that decompress() matches against for zlib data.
    compressor_id = 2

    def __init__(self, level: int):
        """Remember the compression level passed to ``zlib.compress``."""
        self.level = level

    def compress(self, data: bytes) -> bytes:
        """Return `data` compressed with zlib at the configured level."""
        import zlib

        compressed = zlib.compress(data, self.level)
        return compressed
|
||||
|
||||
|
||||
class ZstdContext:
    """Compression context backed by the ``zstandard`` module."""

    # Identifier that decompress() matches against for zstd data.
    compressor_id = 3

    @staticmethod
    def compress(data: bytes) -> bytes:
        """Return `data` compressed by a fresh ``ZstdCompressor``."""
        # ZstdCompressor is not thread safe.
        # TODO: Use a pool?

        import zstandard

        compressor = zstandard.ZstdCompressor()
        return compressor.compress(data)
|
||||
|
||||
|
||||
def decompress(data: bytes, compressor_id: int) -> bytes:
    """Decompress `data` with the compressor identified by `compressor_id`.

    :param data: The compressed payload.
    :param compressor_id: One of the ``compressor_id`` values of
        SnappyContext, ZlibContext or ZstdContext.
    :return: The decompressed bytes.
    :raises ValueError: If `compressor_id` matches none of the contexts.
    """
    if compressor_id == SnappyContext.compressor_id:
        # python-snappy doesn't support the buffer interface.
        # https://github.com/andrix/python-snappy/issues/65
        # This only matters when data is a memoryview since
        # id(bytes(data)) == id(data) when data is a bytes.
        import snappy

        return snappy.uncompress(bytes(data))

    if compressor_id == ZlibContext.compressor_id:
        import zlib

        return zlib.decompress(data)

    if compressor_id == ZstdContext.compressor_id:
        # ZstdDecompressor is not thread safe.
        # TODO: Use a pool?
        import zstandard

        return zstandard.ZstdDecompressor().decompress(data)

    raise ValueError("Unknown compressorId %d" % (compressor_id,))
|
||||
1293
pymongo/asynchronous/cursor.py
Normal file
1293
pymongo/asynchronous/cursor.py
Normal file
File diff suppressed because it is too large
Load Diff
1426
pymongo/asynchronous/database.py
Normal file
1426
pymongo/asynchronous/database.py
Normal file
File diff suppressed because it is too large
Load Diff
1122
pymongo/asynchronous/encryption.py
Normal file
1122
pymongo/asynchronous/encryption.py
Normal file
File diff suppressed because it is too large
Load Diff
270
pymongo/asynchronous/encryption_options.py
Normal file
270
pymongo/asynchronous/encryption_options.py
Normal file
@ -0,0 +1,270 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import] # noqa: F401
|
||||
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
except ImportError:
|
||||
_HAVE_PYMONGOCRYPT = False
|
||||
from bson import int64
|
||||
from pymongo.asynchronous.common import validate_is_mapping
|
||||
from pymongo.asynchronous.uri_parser import _parse_kms_tls_options
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.typings import _DocumentTypeArg
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
    """Options to configure automatic client-side field level encryption."""

    def __init__(
        self,
        kms_providers: Mapping[str, Any],
        key_vault_namespace: str,
        key_vault_client: Optional[AsyncMongoClient[_DocumentTypeArg]] = None,
        schema_map: Optional[Mapping[str, Any]] = None,
        bypass_auto_encryption: bool = False,
        mongocryptd_uri: str = "mongodb://localhost:27020",
        mongocryptd_bypass_spawn: bool = False,
        mongocryptd_spawn_path: str = "mongocryptd",
        mongocryptd_spawn_args: Optional[list[str]] = None,
        kms_tls_options: Optional[Mapping[str, Any]] = None,
        crypt_shared_lib_path: Optional[str] = None,
        crypt_shared_lib_required: bool = False,
        bypass_query_analysis: bool = False,
        encrypted_fields_map: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Options to configure automatic client-side field level encryption.

        Automatic client-side field level encryption requires MongoDB >=4.2
        enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
        supported for operations on a database or view and will result in
        error.

        Although automatic encryption requires MongoDB >=4.2 enterprise or a
        MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
        users. To configure automatic *decryption* without automatic
        *encryption* set ``bypass_auto_encryption=True``. Explicit
        encryption and explicit decryption is also supported for all users
        with the :class:`~pymongo.encryption.ClientEncryption` class.

        See :ref:`automatic-client-side-encryption` for an example.

        :param kms_providers: Map of KMS provider options. The `kms_providers`
            map values differ by provider:

              - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
                These are the AWS access key ID and AWS secret access key used
                to generate KMS messages. An optional "sessionToken" may be
                included to support temporary AWS credentials.
              - `azure`: Map with "tenantId", "clientId", and "clientSecret" as
                strings. Additionally, "identityPlatformEndpoint" may also be
                specified as a string (defaults to 'login.microsoftonline.com').
                These are the Azure Active Directory credentials used to
                generate Azure Key Vault messages.
              - `gcp`: Map with "email" as a string and "privateKey"
                as `bytes` or a base64 encoded string.
                Additionally, "endpoint" may also be specified as a string
                (defaults to 'oauth2.googleapis.com'). These are the
                credentials used to generate Google Cloud KMS messages.
              - `kmip`: Map with "endpoint" as a host with required port.
                For example: ``{"endpoint": "example.com:443"}``.
              - `local`: Map with "key" as `bytes` (96 bytes in length) or
                a base64 encoded string which decodes
                to 96 bytes. "key" is the master key used to encrypt/decrypt
                data keys. This key should be generated and stored as securely
                as possible.

            KMS providers may be specified with an optional name suffix
            separated by a colon, for example "kmip:name" or "aws:name".
            Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
            Named KMS providers enable more than one of each KMS provider type to be configured.
            For example, to configure multiple local KMS providers::

                kms_providers = {
                    "local": {"key": local_kek1},         # Unnamed KMS provider.
                    "local:myname": {"key": local_kek2},  # Named KMS provider with name "myname".
                }

        :param key_vault_namespace: The namespace for the key vault collection.
            The key vault collection contains all data keys used for encryption
            and decryption. Data keys are stored as documents in this MongoDB
            collection. Data keys are protected with encryption by a KMS
            provider.
        :param key_vault_client: By default, the key vault collection
            is assumed to reside in the same MongoDB cluster as the encrypted
            AsyncMongoClient. Use this option to route data key queries to a
            separate MongoDB cluster.
        :param schema_map: Map of collection namespace ("db.coll") to
            JSON Schema.  By default, a collection's JSONSchema is periodically
            polled with the listCollections command. But a JSONSchema may be
            specified locally with the schemaMap option.

            **Supplying a `schema_map` provides more security than relying on
            JSON Schemas obtained from the server. It protects against a
            malicious server advertising a false JSON Schema, which could trick
            the client into sending unencrypted data that should be
            encrypted.**

            Schemas supplied in the schemaMap only apply to configuring
            automatic encryption for client side encryption. Other validation
            rules in the JSON schema will not be enforced by the driver and
            will result in an error.
        :param bypass_auto_encryption: If ``True``, automatic
            encryption will be disabled but automatic decryption will still be
            enabled. Defaults to ``False``.
        :param mongocryptd_uri: The MongoDB URI used to connect
            to the *local* mongocryptd process. Defaults to
            ``'mongodb://localhost:27020'``.
        :param mongocryptd_bypass_spawn: If ``True``, the encrypted
            AsyncMongoClient will not attempt to spawn the mongocryptd process.
            Defaults to ``False``.
        :param mongocryptd_spawn_path: Used for spawning the
            mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
            mongocryptd from the system path.
        :param mongocryptd_spawn_args: A list of string arguments to
            use when spawning the mongocryptd process. Defaults to
            ``['--idleShutdownTimeoutSecs=60']``. If the list does not include
            the ``idleShutdownTimeoutSecs`` option then
            ``'--idleShutdownTimeoutSecs=60'`` will be added. The list passed
            in is copied and never modified.
        :param kms_tls_options: A map of KMS provider names to TLS
            options to use when creating secure connections to KMS providers.
            Accepts the same TLS options as
            :class:`pymongo.asynchronous.mongo_client.AsyncMongoClient`. For example, to
            override the system default CA file::

              kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}

            Or to supply a client certificate::

              kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
        :param crypt_shared_lib_path: Override the path to load the crypt_shared library.
        :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
            unable to load the crypt_shared library.
        :param bypass_query_analysis: If ``True``, disable automatic analysis
            of outgoing commands. Set `bypass_query_analysis` to use explicit
            encryption on indexed fields without the MongoDB Enterprise Advanced
            licensed crypt_shared library.
        :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
            that describe the encrypted fields for Queryable Encryption. For example::

                {
                  "db.encryptedCollection": {
                      "escCollection": "enxcol_.encryptedCollection.esc",
                      "ecocCollection": "enxcol_.encryptedCollection.ecoc",
                      "fields": [
                          {
                              "path": "firstName",
                              "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
                              "bsonType": "string",
                              "queries": {"queryType": "equality"}
                          },
                          {
                              "path": "ssn",
                              "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
                              "bsonType": "string"
                          }
                      ]
                  }
                }

        :raises ConfigurationError: if the pymongocrypt library could not be
            imported.
        :raises TypeError: if `mongocryptd_spawn_args` is given and is not a
            list.

        .. versionchanged:: 4.2
           Added the `encrypted_fields_map`, `crypt_shared_lib_path`, `crypt_shared_lib_required`,
           and `bypass_query_analysis` parameters.

        .. versionchanged:: 4.0
           Added the `kms_tls_options` parameter and the "kmip" KMS provider.

        .. versionadded:: 3.9
        """
        if not _HAVE_PYMONGOCRYPT:
            raise ConfigurationError(
                "client side encryption requires the pymongocrypt library: "
                "install a compatible version with: "
                "python -m pip install 'pymongo[encryption]'"
            )
        # An empty/None map is accepted as "not configured" and skips validation.
        if encrypted_fields_map:
            validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
        self._encrypted_fields_map = encrypted_fields_map
        self._bypass_query_analysis = bypass_query_analysis
        self._crypt_shared_lib_path = crypt_shared_lib_path
        self._crypt_shared_lib_required = crypt_shared_lib_required
        self._kms_providers = kms_providers
        self._key_vault_namespace = key_vault_namespace
        self._key_vault_client = key_vault_client
        self._schema_map = schema_map
        self._bypass_auto_encryption = bypass_auto_encryption
        self._mongocryptd_uri = mongocryptd_uri
        self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
        self._mongocryptd_spawn_path = mongocryptd_spawn_path

        if mongocryptd_spawn_args is None:
            mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
        # Validate before copying so that tuples/strings are still rejected.
        if not isinstance(mongocryptd_spawn_args, list):
            raise TypeError("mongocryptd_spawn_args must be a list")
        # Work on a private copy: appending the default timeout below must not
        # leak into a list the caller still owns (and may reuse elsewhere).
        spawn_args = list(mongocryptd_spawn_args)
        if not any("idleShutdownTimeoutSecs" in s for s in spawn_args):
            spawn_args.append("--idleShutdownTimeoutSecs=60")
        self._mongocryptd_spawn_args = spawn_args

        # Maps KMS provider name to a SSLContext.
        self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
|
||||
|
||||
|
||||
class RangeOpts:
    """Options to configure encrypted queries using the rangePreview algorithm."""

    def __init__(
        self,
        sparsity: int,
        min: Optional[Any] = None,
        max: Optional[Any] = None,
        precision: Optional[int] = None,
    ) -> None:
        """Options to configure encrypted queries using the rangePreview algorithm.

        .. note:: This feature is experimental only, and not intended for public use.

        :param sparsity: An integer.
        :param min: A BSON scalar value corresponding to the type being queried.
        :param max: A BSON scalar value corresponding to the type being queried.
        :param precision: An integer, may only be set for double or decimal128 types.

        .. versionadded:: 4.4
        """
        self.sparsity = sparsity
        self.precision = precision
        self.min = min
        self.max = max

    @property
    def document(self) -> dict[str, Any]:
        """Return the options as a dict, leaving out every unset (``None``) value.

        ``sparsity`` is always wrapped in :class:`bson.int64.Int64`; key order is
        sparsity, precision, min, max.
        """
        candidates = (
            ("sparsity", int64.Int64(self.sparsity)),
            ("precision", self.precision),
            ("min", self.min),
            ("max", self.max),
        )
        return {name: value for name, value in candidates if value is not None}
|
||||
225
pymongo/asynchronous/event_loggers.py
Normal file
225
pymongo/asynchronous/event_loggers.py
Normal file
@ -0,0 +1,225 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Example event logger classes.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
These loggers can be registered using :func:`register` or
|
||||
:class:`~pymongo.mongo_client.MongoClient`.
|
||||
|
||||
``monitoring.register(CommandLogger())``
|
||||
|
||||
or
|
||||
|
||||
``MongoClient(event_listeners=[CommandLogger()])``
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from pymongo.asynchronous import monitoring
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class CommandLogger(monitoring.CommandListener):
    """Example listener that writes one log line per command event.

    Handles :class:`~pymongo.monitoring.CommandStartedEvent`,
    :class:`~pymongo.monitoring.CommandSucceededEvent` and
    :class:`~pymongo.monitoring.CommandFailedEvent`; every message is emitted
    at the `INFO` severity level through the root :mod:`logging` logger.

    .. versionadded:: 3.11
    """

    def started(self, event: monitoring.CommandStartedEvent) -> None:
        """Log that a command was sent to a server."""
        message = (
            f"Command {event.command_name} with request id {event.request_id} "
            f"started on server {event.connection_id}"
        )
        logging.info(message)

    def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
        """Log that a command completed, including its duration."""
        message = (
            f"Command {event.command_name} with request id {event.request_id} "
            f"on server {event.connection_id} "
            f"succeeded in {event.duration_micros} microseconds"
        )
        logging.info(message)

    def failed(self, event: monitoring.CommandFailedEvent) -> None:
        """Log that a command failed, including its duration."""
        message = (
            f"Command {event.command_name} with request id {event.request_id} "
            f"on server {event.connection_id} "
            f"failed in {event.duration_micros} microseconds"
        )
        logging.info(message)
|
||||
|
||||
|
||||
class ServerLogger(monitoring.ServerListener):
    """Example listener that logs server discovery events.

    Handles :class:`~pymongo.monitoring.ServerOpeningEvent`,
    :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
    and :class:`~pymongo.monitoring.ServerClosedEvent` using :mod:`logging`.
    Opening and type changes are logged at `INFO`; removal of a server is
    logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def opened(self, event: monitoring.ServerOpeningEvent) -> None:
        """Log that a server joined the topology."""
        message = f"Server {event.server_address} added to topology {event.topology_id}"
        logging.info(message)

    def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
        """Log a server type transition; unchanged types produce no output."""
        old_type = event.previous_description.server_type
        new_type = event.new_description.server_type
        if new_type != old_type:
            # server_type_name was added in PyMongo 3.4
            message = (
                f"Server {event.server_address} changed type from "
                f"{event.previous_description.server_type_name} to "
                f"{event.new_description.server_type_name}"
            )
            logging.info(message)

    def closed(self, event: monitoring.ServerClosedEvent) -> None:
        """Log (as a warning) that a server left the topology."""
        message = f"Server {event.server_address} removed from topology {event.topology_id}"
        logging.warning(message)
|
||||
|
||||
|
||||
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
    """Example listener that logs server heartbeat events.

    Handles :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
    :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
    and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` using
    :mod:`logging`. Started and succeeded heartbeats are logged at `INFO`;
    failed heartbeats are logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
        """Log that a heartbeat was sent."""
        message = f"Heartbeat sent to server {event.connection_id}"
        logging.info(message)

    def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
        """Log a successful heartbeat together with the reply document."""
        # The reply.document attribute was added in PyMongo 3.4.
        message = (
            f"Heartbeat to server {event.connection_id} "
            f"succeeded with reply {event.reply.document}"
        )
        logging.info(message)

    def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
        """Log (as a warning) a failed heartbeat; ``event.reply`` is reported as the error."""
        message = f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
        logging.warning(message)
|
||||
|
||||
|
||||
class TopologyLogger(monitoring.TopologyListener):
    """Example listener that logs server topology events.

    Handles :class:`~pymongo.monitoring.TopologyOpenedEvent`,
    :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
    and :class:`~pymongo.monitoring.TopologyClosedEvent` using :mod:`logging`.
    Events are logged at `INFO`; a topology without a writable or readable
    server is additionally reported at `WARNING`.

    .. versionadded:: 3.11
    """

    def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
        """Log that a topology was opened."""
        message = f"Topology with id {event.topology_id} opened"
        logging.info(message)

    def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
        """Log a topology update, any type transition, and missing server roles."""
        logging.info(f"Topology description updated for topology id {event.topology_id}")

        old_type = event.previous_description.topology_type
        new_type = event.new_description.topology_type
        if new_type != old_type:
            # topology_type_name was added in PyMongo 3.4
            transition = (
                f"Topology {event.topology_id} changed type from "
                f"{event.previous_description.topology_type_name} to "
                f"{event.new_description.topology_type_name}"
            )
            logging.info(transition)

        # The has_writable_server and has_readable_server methods
        # were added in PyMongo 3.4.
        # These checks run on every update, not only on type changes.
        if not event.new_description.has_writable_server():
            logging.warning("No writable servers available.")
        if not event.new_description.has_readable_server():
            logging.warning("No readable servers available.")

    def closed(self, event: monitoring.TopologyClosedEvent) -> None:
        """Log that a topology was closed."""
        message = f"Topology with id {event.topology_id} closed"
        logging.info(message)
|
||||
|
||||
|
||||
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
    """Example listener that logs connection pool events.

    Handles :class:`~pymongo.monitoring.PoolCreatedEvent`,
    :class:`~pymongo.monitoring.PoolReadyEvent`,
    :class:`~pymongo.monitoring.PoolClearedEvent`,
    :class:`~pymongo.monitoring.PoolClosedEvent`,
    :class:`~pymongo.monitoring.ConnectionCreatedEvent`,
    :class:`~pymongo.monitoring.ConnectionReadyEvent`,
    :class:`~pymongo.monitoring.ConnectionClosedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
    and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
    events and logs them at the `INFO` severity level using :mod:`logging`.

    .. versionadded:: 3.11
    """

    @staticmethod
    def _pool_tag(event: Any) -> str:
        """Return the ``[pool <address>]`` prefix shared by all messages."""
        return f"[pool {event.address}]"

    @staticmethod
    def _conn_tag(event: Any) -> str:
        """Return the ``[pool <address>][conn #<id>]`` prefix for connection events."""
        return f"[pool {event.address}][conn #{event.connection_id}]"

    def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
        """Log that a pool was created."""
        logging.info(f"{self._pool_tag(event)} pool created")

    def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
        """Log that a pool became ready."""
        logging.info(f"{self._pool_tag(event)} pool ready")

    def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
        """Log that a pool was cleared."""
        logging.info(f"{self._pool_tag(event)} pool cleared")

    def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
        """Log that a pool was closed."""
        logging.info(f"{self._pool_tag(event)} pool closed")

    def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
        """Log that a connection was created."""
        logging.info(f"{self._conn_tag(event)} connection created")

    def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
        """Log that a connection finished its setup."""
        logging.info(f"{self._conn_tag(event)} connection setup succeeded")

    def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
        """Log that a connection was closed, with the reported reason."""
        logging.info(f'{self._conn_tag(event)} connection closed, reason: "{event.reason}"')

    def connection_check_out_started(
        self, event: monitoring.ConnectionCheckOutStartedEvent
    ) -> None:
        """Log that a check out from the pool was requested."""
        logging.info(f"{self._pool_tag(event)} connection check out started")

    def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
        """Log that a check out failed, with the reported reason."""
        logging.info(
            f"{self._pool_tag(event)} connection check out failed, reason: {event.reason}"
        )

    def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
        """Log that a connection left the pool."""
        logging.info(f"{self._conn_tag(event)} connection checked out of pool")

    def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
        """Log that a connection was returned to the pool."""
        logging.info(f"{self._conn_tag(event)} connection checked into pool")
|
||||
@ -21,17 +21,12 @@ import itertools
|
||||
from typing import Any, Generic, Mapping, Optional
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo import common
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.asynchronous.typings import ClusterTime, _DocumentType
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import ClusterTime, _DocumentType
|
||||
|
||||
|
||||
class HelloCompat:
|
||||
CMD = "hello"
|
||||
LEGACY_CMD = "ismaster"
|
||||
PRIMARY = "isWritablePrimary"
|
||||
LEGACY_PRIMARY = "ismaster"
|
||||
LEGACY_ERROR = "not master"
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _get_server_type(doc: Mapping[str, Any]) -> int:
|
||||
26
pymongo/asynchronous/hello_compat.py
Normal file
26
pymongo/asynchronous/hello_compat.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The HelloCompat class, placed here to break circular import issues."""
|
||||
from __future__ import annotations
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class HelloCompat:
    """String constants for the "hello" handshake and its legacy "ismaster" form."""

    # Current command name and the reply field naming the writable primary.
    CMD = "hello"
    PRIMARY = "isWritablePrimary"

    # Legacy equivalents, plus the legacy "not primary" error text.
    LEGACY_CMD = "ismaster"
    LEGACY_PRIMARY = "ismaster"
    LEGACY_ERROR = "not master"
|
||||
321
pymongo/asynchronous/helpers.py
Normal file
321
pymongo/asynchronous/helpers.py
Normal file
@ -0,0 +1,321 @@
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Bits and pieces used by the driver that don't really fit elsewhere."""
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
import traceback
|
||||
from collections import abc
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Iterable,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.asynchronous.hello_compat import HelloCompat
|
||||
from pymongo.errors import (
|
||||
CursorNotFound,
|
||||
DuplicateKeyError,
|
||||
ExecutionTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
WriteConcernError,
|
||||
WriteError,
|
||||
WTimeoutError,
|
||||
_wtimeout_error,
|
||||
)
|
||||
from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.operations import _IndexList
|
||||
from pymongo.asynchronous.typings import _DocumentOut
|
||||
from pymongo.cursor_shared import _Hint
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _gen_index_name(keys: _IndexList) -> str:
|
||||
"""Generate an index name from the set of fields it is over."""
|
||||
return "_".join(["{}_{}".format(*item) for item in keys])
|
||||
|
||||
|
||||
def _index_list(
|
||||
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
|
||||
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
|
||||
"""Helper to generate a list of (key, direction) pairs.
|
||||
|
||||
Takes such a list, or a single key, or a single key and direction.
|
||||
"""
|
||||
if direction is not None:
|
||||
if not isinstance(key_or_list, str):
|
||||
raise TypeError("Expected a string and a direction")
|
||||
return [(key_or_list, direction)]
|
||||
else:
|
||||
if isinstance(key_or_list, str):
|
||||
return [(key_or_list, ASCENDING)]
|
||||
elif isinstance(key_or_list, abc.ItemsView):
|
||||
return list(key_or_list) # type: ignore[arg-type]
|
||||
elif isinstance(key_or_list, abc.Mapping):
|
||||
return list(key_or_list.items())
|
||||
elif not isinstance(key_or_list, (list, tuple)):
|
||||
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
||||
values: list[tuple[str, int]] = []
|
||||
for item in key_or_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
values.append(item)
|
||||
return values
|
||||
|
||||
|
||||
def _index_document(index_list: _IndexList) -> dict[str, Any]:
|
||||
"""Helper to generate an index specifying document.
|
||||
|
||||
Takes a list of (key, direction) pairs.
|
||||
"""
|
||||
if not isinstance(index_list, (list, tuple, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
|
||||
)
|
||||
if not len(index_list):
|
||||
raise ValueError("key_or_list must not be empty")
|
||||
|
||||
index: dict[str, Any] = {}
|
||||
|
||||
if isinstance(index_list, abc.Mapping):
|
||||
for key in index_list:
|
||||
value = index_list[key]
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
else:
|
||||
for item in index_list:
|
||||
if isinstance(item, str):
|
||||
item = (item, ASCENDING) # noqa: PLW2901
|
||||
key, value = item
|
||||
_validate_index_key_pair(key, value)
|
||||
index[key] = value
|
||||
return index
|
||||
|
||||
|
||||
def _validate_index_key_pair(key: Any, value: Any) -> None:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("first item in each key pair must be an instance of str")
|
||||
if not isinstance(value, (str, int, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"second item in each key pair must be 1, -1, "
|
||||
"'2d', or another valid MongoDB index specifier."
|
||||
)
|
||||
|
||||
|
||||
def _check_command_response(
|
||||
response: _DocumentOut,
|
||||
max_wire_version: Optional[int],
|
||||
allowable_errors: Optional[Container[Union[int, str]]] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
) -> None:
|
||||
"""Check the response to a command for errors."""
|
||||
if "ok" not in response:
|
||||
# Server didn't recognize our message as a command.
|
||||
raise OperationFailure(
|
||||
response.get("$err"), # type: ignore[arg-type]
|
||||
response.get("code"),
|
||||
response,
|
||||
max_wire_version,
|
||||
)
|
||||
|
||||
if parse_write_concern_error and "writeConcernError" in response:
|
||||
_error = response["writeConcernError"]
|
||||
_labels = response.get("errorLabels")
|
||||
if _labels:
|
||||
_error.update({"errorLabels": _labels})
|
||||
_raise_write_concern_error(_error)
|
||||
|
||||
if response["ok"]:
|
||||
return
|
||||
|
||||
details = response
|
||||
# Mongos returns the error details in a 'raw' object
|
||||
# for some errors.
|
||||
if "raw" in response:
|
||||
for shard in response["raw"].values():
|
||||
# Grab the first non-empty raw error from a shard.
|
||||
if shard.get("errmsg") and not shard.get("ok"):
|
||||
details = shard
|
||||
break
|
||||
|
||||
errmsg = details["errmsg"]
|
||||
code = details.get("code")
|
||||
|
||||
# For allowable errors, only check for error messages when the code is not
|
||||
# included.
|
||||
if allowable_errors:
|
||||
if code is not None:
|
||||
if code in allowable_errors:
|
||||
return
|
||||
elif errmsg in allowable_errors:
|
||||
return
|
||||
|
||||
# Server is "not primary" or "recovering"
|
||||
if code is not None:
|
||||
if code in _NOT_PRIMARY_CODES:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
|
||||
raise NotPrimaryError(errmsg, response)
|
||||
|
||||
# Other errors
|
||||
# findAndModify with upsert can raise duplicate key error
|
||||
if code in (11000, 11001, 12582):
|
||||
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
|
||||
elif code == 50:
|
||||
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
|
||||
elif code == 43:
|
||||
raise CursorNotFound(errmsg, code, response, max_wire_version)
|
||||
|
||||
raise OperationFailure(errmsg, code, response, max_wire_version)
|
||||
|
||||
|
||||
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
    """Raise an exception for the final entry of ``write_errors``.

    :raises DuplicateKeyError: if the last error has code 11000.
    :raises WriteError: for any other last error.
    """
    # If the last batch had multiple errors only report
    # the last error to emulate continue_on_error.
    last = write_errors[-1]
    code = last.get("code")
    if code == 11000:
        raise DuplicateKeyError(last.get("errmsg"), 11000, last)
    raise WriteError(last.get("errmsg"), code, last)
|
||||
|
||||
|
||||
def _raise_write_concern_error(error: Any) -> NoReturn:
    """Raise the exception matching a write concern error document.

    :raises WTimeoutError: if ``_wtimeout_error`` recognises the document.
    :raises WriteConcernError: otherwise.
    """
    # Make sure we raise WTimeoutError for write concern timeouts.
    error_type = WTimeoutError if _wtimeout_error(error) else WriteConcernError
    raise error_type(error.get("errmsg"), error.get("code"), error)
|
||||
|
||||
|
||||
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
|
||||
"""Return the writeConcernError or None."""
|
||||
wce = result.get("writeConcernError")
|
||||
if wce:
|
||||
# The server reports errorLabels at the top level but it's more
|
||||
# convenient to attach it to the writeConcernError doc itself.
|
||||
error_labels = result.get("errorLabels")
|
||||
if error_labels:
|
||||
# Copy to avoid changing the original document.
|
||||
wce = wce.copy()
|
||||
wce["errorLabels"] = error_labels
|
||||
return wce
|
||||
|
||||
|
||||
def _check_write_command_response(result: Mapping[str, Any]) -> None:
    """Backward compatibility helper for write command error handling.

    Raises for ``writeErrors`` first; a write concern error is only raised
    when there are no write errors. Returns None when neither is present.
    """
    # Prefer write errors over write concern errors
    if write_errors := result.get("writeErrors"):
        _raise_last_write_error(write_errors)

    if wce := _get_wce_doc(result):
        _raise_write_concern_error(wce)
|
||||
|
||||
|
||||
def _fields_list_to_dict(
|
||||
fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
|
||||
) -> Mapping[str, Any]:
|
||||
"""Takes a sequence of field names and returns a matching dictionary.
|
||||
|
||||
["a", "b"] becomes {"a": 1, "b": 1}
|
||||
|
||||
and
|
||||
|
||||
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
|
||||
"""
|
||||
if isinstance(fields, abc.Mapping):
|
||||
return fields
|
||||
|
||||
if isinstance(fields, (abc.Sequence, abc.Set)):
|
||||
if not all(isinstance(field, str) for field in fields):
|
||||
raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
|
||||
return dict.fromkeys(fields, 1)
|
||||
|
||||
raise TypeError(f"{option_name} must be a mapping or list of key names")
|
||||
|
||||
|
||||
def _handle_exception() -> None:
|
||||
"""Print exceptions raised by subscribers to stderr."""
|
||||
# Heavily influenced by logging.Handler.handleError.
|
||||
|
||||
# See note here:
|
||||
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
|
||||
if sys.stderr:
|
||||
einfo = sys.exc_info()
|
||||
try:
|
||||
traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
del einfo
|
||||
|
||||
|
||||
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _handle_reauth(func: F) -> F:
    """Decorate a coroutine function so it re-authenticates and retries once.

    The wrapper awaits ``func``. If it raises ``OperationFailure`` with the
    re-authentication-required code, the wrapper looks through the positional
    arguments for a ``Connection`` (or a ``_BulkWriteContext`` holding one),
    re-authenticates that connection and awaits ``func`` a second time.

    Callers may pass ``no_reauth=True`` as a keyword argument to disable the
    retry; the keyword is consumed here and never forwarded to ``func``.

    :param func: The coroutine function to wrap.
    :return: The wrapping coroutine function, typed as ``func``.
    :raises OperationFailure: Re-raised when ``no_reauth`` is set, when the
        error code is not the re-authentication code, or when no connection
        can be found among the arguments.
    """

    async def inner(*args: Any, **kwargs: Any) -> Any:
        no_reauth = kwargs.pop("no_reauth", False)
        # Imported inside the wrapper rather than at module level.
        from pymongo.asynchronous.message import _BulkWriteContext
        from pymongo.asynchronous.pool import Connection

        try:
            return await func(*args, **kwargs)
        except OperationFailure as exc:
            if no_reauth:
                raise
            if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
                # Look for an argument that either is a Connection
                # or has a connection attribute, so we can trigger
                # a reauth.
                conn = None
                for arg in args:
                    if isinstance(arg, Connection):
                        conn = arg
                        break
                    if isinstance(arg, _BulkWriteContext):
                        conn = arg.conn
                        break
                if conn:
                    await conn.authenticate(reauthenticate=True)
                else:
                    raise
                # The retry must be awaited: returning the bare call would
                # hand the caller a coroutine object instead of the result.
                return await func(*args, **kwargs)
            raise

    return cast(F, inner)
|
||||
|
||||
|
||||
async def anext(cls: Any) -> Any:
    """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext.

    :param cls: The asynchronous iterator to advance.
    :return: The next item produced by ``cls``.
    """
    if sys.version_info < (3, 10):
        # No builtin is available here; call the protocol method directly.
        return await cls.__anext__()
    return await builtins.anext(cls)
|
||||
171
pymongo/asynchronous/logger.py
Normal file
171
pymongo/asynchronous/logger.py
Normal file
@ -0,0 +1,171 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from bson import UuidRepresentation, json_util
|
||||
from bson.json_util import JSONOptions, _truncate_documents
|
||||
from pymongo.asynchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class _CommandStatusMessage(str, enum.Enum):
    """Message strings for the stages of a command (started/succeeded/failed)."""

    STARTED = "Command started"
    SUCCEEDED = "Command succeeded"
    FAILED = "Command failed"
|
||||
|
||||
|
||||
class _ServerSelectionStatusMessage(str, enum.Enum):
    """Message strings for the stages of server selection."""

    STARTED = "Server selection started"
    SUCCEEDED = "Server selection succeeded"
    FAILED = "Server selection failed"
    WAITING = "Waiting for suitable server to become available"
|
||||
|
||||
|
||||
class _ConnectionStatusMessage(str, enum.Enum):
    """Message strings for connection pool, connection and checkout events."""

    # Pool lifecycle.
    POOL_CREATED = "Connection pool created"
    POOL_READY = "Connection pool ready"
    POOL_CLOSED = "Connection pool closed"
    POOL_CLEARED = "Connection pool cleared"

    # Individual connection lifecycle.
    CONN_CREATED = "Connection created"
    CONN_READY = "Connection ready"
    CONN_CLOSED = "Connection closed"

    # Checking a connection out of, and back into, the pool.
    CHECKOUT_STARTED = "Connection checkout started"
    CHECKOUT_SUCCEEDED = "Connection checked out"
    CHECKOUT_FAILED = "Connection checkout failed"
    CHECKEDIN = "Connection checked in"
|
||||
|
||||
|
||||
# Maximum length of a logged document when the environment does not supply
# one (see LogMessage._redact).
_DEFAULT_DOCUMENT_LENGTH = 1000
# Commands whose "command"/"reply" documents are logged as an empty document
# (see LogMessage._is_sensitive).
_SENSITIVE_COMMANDS = [
    "authenticate",
    "saslStart",
    "saslContinue",
    "getnonce",
    "createUser",
    "updateUser",
    "copydbgetnonce",
    "copydbsaslstart",
    "copydb",
]
# Handshake command names; these are treated as sensitive only when
# speculative authentication is involved.
_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"]
# The only keys kept from a "failure" document flagged as a server-side error.
_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"]
# Keys of a log record whose values are documents to redact/truncate.
_DOCUMENT_NAMES = ["command", "reply", "failure"]
# JSON options used when serialising log records and their documents.
_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD)
# One logger per logging component.
_COMMAND_LOGGER = logging.getLogger("pymongo.command")
_CONNECTION_LOGGER = logging.getLogger("pymongo.connection")
_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection")
_CLIENT_LOGGER = logging.getLogger("pymongo.client")
# Human-readable explanations for connection-closed and checkout-failed
# reasons (looked up by _verbose_connection_error_reason).
_VERBOSE_CONNECTION_ERROR_REASONS = {
    ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed",
    ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed",
    ConnectionClosedReason.STALE: "Connection pool was stale",
    ConnectionClosedReason.ERROR: "An error occurred while using the connection",
    ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection",
    ConnectionClosedReason.IDLE: "Connection was idle too long",
    ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout",
}
|
||||
|
||||
|
||||
def _debug_log(logger: logging.Logger, **fields: Any) -> None:
    """Emit ``fields`` as a ``LogMessage`` on ``logger`` at DEBUG level.

    :param logger: The logger to write to.
    :param fields: Keyword arguments forwarded to ``LogMessage``.
    """
    message = LogMessage(**fields)
    logger.debug(message)
|
||||
|
||||
|
||||
def _verbose_connection_error_reason(reason: str) -> str:
    """Map a connection error reason to its verbose description.

    :param reason: The reason to look up.
    :return: The matching entry of ``_VERBOSE_CONNECTION_ERROR_REASONS``, or
        ``reason`` itself when there is no entry for it.
    """
    try:
        return _VERBOSE_CONNECTION_ERROR_REASONS[reason]
    except KeyError:
        return reason
|
||||
|
||||
|
||||
def _info_log(logger: logging.Logger, **fields: Any) -> None:
    """Emit ``fields`` as a ``LogMessage`` on ``logger`` at INFO level.

    :param logger: The logger to write to.
    :param fields: Keyword arguments forwarded to ``LogMessage``.
    """
    message = LogMessage(**fields)
    logger.info(message)
|
||||
|
||||
|
||||
def _log_or_warn(logger: logging.Logger, message: str) -> None:
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
logger.info(message)
|
||||
else:
|
||||
# stacklevel=4 ensures that the warning is for the user's code.
|
||||
warnings.warn(message, UserWarning, stacklevel=4)
|
||||
|
||||
|
||||
class LogMessage:
    """A structured log record that is rendered to JSON only when needed.

    The keyword arguments are stored as given. Redaction, truncation and
    serialisation happen on the first ``str()`` call, so a record that is
    never formatted costs nothing beyond construction.
    """

    __slots__ = ("_kwargs", "_redacted")

    def __init__(self, **kwargs: Any):
        """Store the log fields; no processing happens here."""
        self._kwargs = kwargs
        self._redacted = False

    def __str__(self) -> str:
        """Return the redacted fields serialised as a JSON string."""
        self._redact()
        return "%s" % (
            json_util.dumps(
                self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__()
            )
        )

    @staticmethod
    def _max_document_length() -> int:
        """Return the maximum length of a logged document.

        Reads ``MONGODB_LOG_MAX_DOCUMENT_LENGTH`` from the environment. The
        misspelled ``MONGOB_LOG_MAX_DOCUMENT_LENGTH`` that this module used to
        read is still honoured as a fallback so existing setups keep working.
        A missing, non-integer or negative value yields the default.
        """
        raw = os.getenv("MONGODB_LOG_MAX_DOCUMENT_LENGTH")
        if raw is None:
            raw = os.getenv("MONGOB_LOG_MAX_DOCUMENT_LENGTH")
        if raw is None:
            return _DEFAULT_DOCUMENT_LENGTH
        try:
            length = int(raw)
        except ValueError:
            # A bad setting must not make formatting a log record raise.
            return _DEFAULT_DOCUMENT_LENGTH
        if length < 0:
            return _DEFAULT_DOCUMENT_LENGTH
        return length

    def _is_sensitive(self, doc_name: str) -> bool:
        """Return True if the ``doc_name`` document must not be logged.

        A document is sensitive when the command is one of
        ``_SENSITIVE_COMMANDS``, or when it is a hello command involved in
        speculative authentication.

        Note: this removes the ``speculative_authenticate`` field from the
        record as a side effect.
        """
        is_speculative_authenticate = (
            self._kwargs.pop("speculative_authenticate", False)
            or "speculativeAuthenticate" in self._kwargs[doc_name]
        )
        # A record without a command name is not a sensitive command; use
        # .get() so both checks tolerate the missing key.
        command_name = self._kwargs.get("commandName")
        is_sensitive_command = command_name in _SENSITIVE_COMMANDS
        is_sensitive_hello = command_name in _HELLO_COMMANDS and is_speculative_authenticate

        return bool(is_sensitive_command or is_sensitive_hello)

    def _redact(self) -> None:
        """Normalise, redact and truncate the record in place (runs once)."""
        if self._redacted:
            return
        # Fields explicitly set to None are omitted from the output.
        self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None}
        # A timedelta-like duration is converted to milliseconds.
        if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"):
            self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000
        if "serviceId" in self._kwargs:
            self._kwargs["serviceId"] = str(self._kwargs["serviceId"])
        document_length = self._max_document_length()
        is_server_side_error = self._kwargs.pop("isServerSideError", False)

        for doc_name in _DOCUMENT_NAMES:
            doc = self._kwargs.get(doc_name)
            if doc:
                if doc_name == "failure" and is_server_side_error:
                    # Keep only the allow-listed keys of a server-side failure.
                    doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS}
                if doc_name != "failure" and self._is_sensitive(doc_name):
                    # Sensitive documents are logged as an empty document.
                    doc = json_util.dumps({})
                else:
                    truncated_doc = _truncate_documents(doc, document_length)[0]
                    doc = json_util.dumps(
                        truncated_doc,
                        json_options=_JSON_OPTIONS,
                        default=lambda o: o.__repr__(),
                    )
                if len(doc) > document_length:
                    # Cut the serialised text and mark it as truncated.
                    doc = (
                        doc.encode()[:document_length].decode("unicode-escape", "ignore")
                    ) + "..."
                self._kwargs[doc_name] = doc
        self._redacted = True
|
||||
125
pymongo/asynchronous/max_staleness_selectors.py
Normal file
125
pymongo/asynchronous/max_staleness_selectors.py
Normal file
@ -0,0 +1,125 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Criteria to select ServerDescriptions based on maxStalenessSeconds.
|
||||
|
||||
The Max Staleness Spec says: When there is a known primary P,
|
||||
a secondary S's staleness is estimated with this formula:
|
||||
|
||||
(S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate)
|
||||
+ heartbeatFrequencyMS
|
||||
|
||||
When there is no known primary, a secondary S's staleness is estimated with:
|
||||
|
||||
SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS
|
||||
|
||||
where "SMax" is the secondary with the greatest lastWriteDate.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
|
||||
# 10 seconds to refresh secondaries' lastWriteDate values.
|
||||
IDLE_WRITE_PERIOD = 10
|
||||
SMALLEST_MAX_STALENESS = 90
|
||||
|
||||
|
||||
def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None:
    """Check ``max_staleness`` against the minimums allowed.

    :param max_staleness: The configured maxStalenessSeconds.
    :param heartbeat_frequency: The heartbeat frequency, in seconds.
    :raises ConfigurationError: If ``max_staleness`` is smaller than
        ``heartbeat_frequency + IDLE_WRITE_PERIOD`` or smaller than
        ``SMALLEST_MAX_STALENESS``.
    """
    # We checked for max staleness -1 before this, it must be positive here.
    heartbeat_minimum = heartbeat_frequency + IDLE_WRITE_PERIOD
    if max_staleness < heartbeat_minimum:
        message = (
            "maxStalenessSeconds must be at least heartbeatFrequencyMS +"
            " %d seconds. maxStalenessSeconds is set to %d,"
            " heartbeatFrequencyMS is set to %d."
        ) % (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000)
        raise ConfigurationError(message)

    if max_staleness < SMALLEST_MAX_STALENESS:
        message = (
            "maxStalenessSeconds must be at least %d. "
            "maxStalenessSeconds is set to %d."
        ) % (SMALLEST_MAX_STALENESS, max_staleness)
        raise ConfigurationError(message)
|
||||
|
||||
|
||||
def _with_primary(max_staleness: int, selection: Selection) -> Selection:
    """Apply max_staleness, in seconds, to a Selection with a known primary.

    Secondaries whose estimated staleness exceeds ``max_staleness`` are
    dropped; every server that is not a secondary is kept.
    """
    primary = selection.primary
    assert primary
    kept = []

    for candidate in selection.server_descriptions:
        if candidate.server_type != SERVER_TYPE.RSSecondary:
            kept.append(candidate)
            continue

        # See max-staleness.rst for explanation of this formula.
        assert candidate.last_write_date and primary.last_write_date  # noqa: PT018
        secondary_lag = candidate.last_update_time - candidate.last_write_date
        primary_lag = primary.last_update_time - primary.last_write_date
        staleness = secondary_lag - primary_lag + selection.heartbeat_frequency

        if staleness <= max_staleness:
            kept.append(candidate)

    return selection.with_server_descriptions(kept)
|
||||
|
||||
|
||||
def _no_primary(max_staleness: int, selection: Selection) -> Selection:
    """Apply max_staleness, in seconds, to a Selection with no known primary.

    Staleness of each secondary is measured against the secondary with the
    greatest last write date. Servers that are not secondaries are kept.
    """
    # Secondary that's replicated the most recent writes.
    freshest = selection.secondary_with_max_last_write_date()
    if not freshest:
        # No secondaries and no primary, short-circuit out of here.
        return selection.with_server_descriptions([])

    kept = []

    for candidate in selection.server_descriptions:
        if candidate.server_type != SERVER_TYPE.RSSecondary:
            kept.append(candidate)
            continue

        # See max-staleness.rst for explanation of this formula.
        assert freshest.last_write_date and candidate.last_write_date  # noqa: PT018
        staleness = (
            freshest.last_write_date - candidate.last_write_date + selection.heartbeat_frequency
        )

        if staleness <= max_staleness:
            kept.append(candidate)

    return selection.with_server_descriptions(kept)
|
||||
|
||||
|
||||
def select(max_staleness: int, selection: Selection) -> Selection:
    """Apply max_staleness, in seconds, to a Selection.

    A ``max_staleness`` of -1 disables the filter and returns ``selection``
    unchanged.

    :raises ConfigurationError: If ``max_staleness`` fails validation.
    """
    if max_staleness == -1:
        return selection

    # Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or
    # ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness <
    # heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90.
    _validate_max_staleness(max_staleness, selection.heartbeat_frequency)

    apply_filter = _with_primary if selection.primary else _no_primary
    return apply_filter(max_staleness, selection)
|
||||
1760
pymongo/asynchronous/message.py
Normal file
1760
pymongo/asynchronous/message.py
Normal file
File diff suppressed because it is too large
Load Diff
2543
pymongo/asynchronous/mongo_client.py
Normal file
2543
pymongo/asynchronous/mongo_client.py
Normal file
File diff suppressed because it is too large
Load Diff
487
pymongo/asynchronous/monitor.py
Normal file
487
pymongo/asynchronous/monitor.py
Normal file
@ -0,0 +1,487 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Class to monitor a MongoDB server on a background thread."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.asynchronous import common, periodic_executor
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.periodic_executor import _shutdown_executors
|
||||
from pymongo.asynchronous.pool import _is_faas
|
||||
from pymongo.asynchronous.read_preferences import MovingAverage
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.srv_resolver import _SrvResolver
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.lock import _create_lock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import Connection, Pool, _CancellationContext
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _sanitize(error: Exception) -> None:
|
||||
"""PYTHON-2433 Clear error traceback info."""
|
||||
error.__traceback__ = None
|
||||
error.__context__ = None
|
||||
error.__cause__ = None
|
||||
|
||||
|
||||
class MonitorBase:
    def __init__(self, topology: Topology, name: str, interval: int, min_interval: float):
        """Base class to do periodic work on a background thread.

        The background thread is signaled to stop when the Topology or
        this instance is freed.

        Subclasses provide the work as an ``async def _run(self)`` method,
        which the executor target awaits on every iteration.

        :param topology: The Topology; only a weak proxy to it is stored.
        :param name: Passed to the PeriodicExecutor as its ``name``.
        :param interval: Passed to the PeriodicExecutor as its ``interval``.
        :param min_interval: Passed to the PeriodicExecutor as its
            ``min_interval``.
        """

        # We strongly reference the executor and it weakly references us via
        # this closure. When the monitor is freed, stop the executor soon.
        async def target() -> bool:
            # self_ref is assigned further down; the closure resolves it when
            # the target actually runs.
            monitor = self_ref()
            if monitor is None:
                return False  # Stop the executor.
            await monitor._run()  # type:ignore[attr-defined]
            return True

        executor = periodic_executor.PeriodicExecutor(
            interval=interval, min_interval=min_interval, target=target, name=name
        )

        self._executor = executor

        def _on_topology_gc(dummy: Optional[Topology] = None) -> None:
            # This prevents GC from waiting 10 seconds for hello to complete
            # See test_cleanup_executors_on_client_del.
            monitor = self_ref()
            if monitor:
                monitor.gc_safe_close()

        # Avoid cycles. When self or topology is freed, stop executor soon.
        self_ref = weakref.ref(self, executor.close)
        self._topology = weakref.proxy(topology, _on_topology_gc)
        # Track this monitor in the module-level registry.
        _register(self)

    def open(self) -> None:
        """Start monitoring, or restart after a fork.

        Multiple calls have no effect.
        """
        self._executor.open()

    def gc_safe_close(self) -> None:
        """GC safe close.

        Synchronous: only closes the executor.
        """
        self._executor.close()

    async def close(self) -> None:
        """Close and stop monitoring.

        open() restarts the monitor after closing.
        """
        self.gc_safe_close()

    def join(self, timeout: Optional[int] = None) -> None:
        """Wait for the monitor to stop.

        :param timeout: Forwarded to the executor's ``join``.
        """
        self._executor.join(timeout)

    def request_check(self) -> None:
        """If the monitor is sleeping, wake it soon."""
        self._executor.wake()
|
||||
|
||||
|
||||
class Monitor(MonitorBase):
    def __init__(
        self,
        server_description: ServerDescription,
        topology: Topology,
        pool: Pool,
        topology_settings: TopologySettings,
    ):
        """Class to monitor a MongoDB server on a background thread.

        Pass an initial ServerDescription, a Topology, a Pool, and
        TopologySettings.

        The Topology is weakly referenced. The Pool must be exclusive to this
        Monitor.
        """
        super().__init__(
            topology,
            "pymongo_server_monitor_thread",
            topology_settings.heartbeat_frequency,
            common.MIN_HEARTBEAT_INTERVAL,
        )
        self._server_description = server_description
        self._pool = pool
        self._settings = topology_settings
        self._listeners = self._settings._pool_options._event_listeners
        # Heartbeat events are published only when listeners ask for them.
        self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
        # Cancellation context of the connection used by the latest check.
        self._cancel_context: Optional[_CancellationContext] = None
        # The RTT monitor gets its own pool for the same server address.
        self._rtt_monitor = _RttMonitor(
            topology,
            topology_settings,
            topology._create_pool_for_monitor(server_description.address),
        )
        # "stream" and "poll" force the mode; any other setting streams
        # unless _is_faas() reports true.
        if topology_settings.server_monitoring_mode == "stream":
            self._stream = True
        elif topology_settings.server_monitoring_mode == "poll":
            self._stream = False
        else:
            self._stream = not _is_faas()

    def cancel_check(self) -> None:
        """Cancel any concurrent hello check.

        Note: this is called from a weakref.proxy callback and MUST NOT take
        any locks.
        """
        context = self._cancel_context
        if context:
            # Note: we cannot close the socket because doing so may cause
            # concurrent reads/writes to hang until a timeout occurs
            # (depending on the platform).
            context.cancel()

    async def _start_rtt_monitor(self) -> None:
        """Start an _RttMonitor that periodically runs ping."""
        # If this monitor is closed directly before (or during) this open()
        # call, the _RttMonitor will not be closed. Checking if this monitor
        # was closed directly after resolves the race.
        self._rtt_monitor.open()
        if self._executor._stopped:
            await self._rtt_monitor.close()

    def gc_safe_close(self) -> None:
        """Synchronously stop the executor and the RTT monitor, then cancel
        any in-progress check.
        """
        self._executor.close()
        self._rtt_monitor.gc_safe_close()
        self.cancel_check()

    async def close(self) -> None:
        """Stop monitoring, close the RTT monitor and reset the pool."""
        self.gc_safe_close()
        await self._rtt_monitor.close()
        # Increment the generation and maybe close the socket. If the executor
        # thread has the socket checked out, it will be closed when checked in.
        await self._reset_connection()

    async def _reset_connection(self) -> None:
        """Reset this monitor's pool."""
        # Clear our pooled connection.
        await self._pool.reset()

    async def _run(self) -> None:
        """Run one check and report the new description to the Topology.

        Invoked by the executor target defined in MonitorBase.
        """
        try:
            prev_sd = self._server_description
            try:
                self._server_description = await self._check_server()
            except _OperationCancelled as exc:
                _sanitize(exc)
                # Already closed the connection, wait for the next check.
                self._server_description = ServerDescription(
                    self._server_description.address, error=exc
                )
                if prev_sd.is_server_type_known:
                    # Immediately retry since we've already waited 500ms to
                    # discover that we've been cancelled.
                    self._executor.skip_sleep()
                # A cancelled check is not reported to the Topology.
                return

            # Update the Topology and clear the server pool on error.
            await self._topology.on_change(
                self._server_description,
                reset_pool=self._server_description.error,
                interrupt_connections=isinstance(self._server_description.error, NetworkTimeout),
            )

            if self._stream and (
                self._server_description.is_server_type_known
                and self._server_description.topology_version
            ):
                await self._start_rtt_monitor()
                # Immediately check for the next streaming response.
                self._executor.skip_sleep()

            if self._server_description.error and prev_sd.is_server_type_known:
                # Immediately retry on network errors.
                self._executor.skip_sleep()
        except ReferenceError:
            # Topology was garbage-collected.
            await self.close()

    async def _check_server(self) -> ServerDescription:
        """Call hello or read the next streaming response.

        Returns a ServerDescription.

        On failure the heartbeat-failed event is published (when enabled)
        and the pool is reset. ReferenceError and _OperationCancelled
        propagate to the caller; any other exception is returned wrapped in
        a ServerDescription carrying the error.
        """
        start = time.monotonic()
        try:
            try:
                return await self._check_once()
            except (OperationFailure, NotPrimaryError) as exc:
                # Update max cluster time even when hello fails.
                details = cast(Mapping[str, Any], exc.details)
                self._topology.receive_cluster_time(details.get("$clusterTime"))
                raise
        except ReferenceError:
            # Handled by _run; must not be converted into a ServerDescription.
            raise
        except Exception as error:
            _sanitize(error)
            sd = self._server_description
            address = sd.address
            duration = time.monotonic() - start
            if self._publish:
                awaited = bool(self._stream and sd.is_server_type_known and sd.topology_version)
                assert self._listeners is not None
                self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
            await self._reset_connection()
            if isinstance(error, _OperationCancelled):
                raise
            self._rtt_monitor.reset()
            # Server type defaults to Unknown.
            return ServerDescription(address, error=error)

    async def _check_once(self) -> ServerDescription:
        """A single attempt to call hello.

        Returns a ServerDescription, or raises an exception.
        """
        address = self._server_description.address
        if self._publish:
            assert self._listeners is not None
            sd = self._server_description
            # XXX: "awaited" could be incorrectly set to True in the rare case
            # the pool checkout closes and recreates a connection.
            awaited = bool(
                self._pool.conns
                and self._stream
                and sd.is_server_type_known
                and sd.topology_version
            )
            self._listeners.publish_server_heartbeat_started(address, awaited)

        # If the previous check was cancelled, reset the pool before
        # checking out a connection.
        if self._cancel_context and self._cancel_context.cancelled:
            await self._reset_connection()
        async with self._pool.checkout() as conn:
            # Remember the context so cancel_check() can cancel this check.
            self._cancel_context = conn.cancel_context
            response, round_trip_time = await self._check_with_socket(conn)
            # Only non-awaitable responses contribute an RTT sample.
            if not response.awaitable:
                self._rtt_monitor.add_sample(round_trip_time)

            avg_rtt, min_rtt = self._rtt_monitor.get()
            sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
            if self._publish:
                assert self._listeners is not None
                self._listeners.publish_server_heartbeat_succeeded(
                    address, round_trip_time, response, response.awaitable
                )
            return sd

    async def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
        """Return (Hello, round_trip_time).

        Can raise ConnectionFailure or OperationFailure.
        """
        cluster_time = self._topology.max_cluster_time()
        start = time.monotonic()
        if conn.more_to_come:
            # Read the next streaming hello (MongoDB 4.4+).
            response = Hello(await conn._next_reply(), awaitable=True)
        elif (
            self._stream and conn.performed_handshake and self._server_description.topology_version
        ):
            # Initiate streaming hello (MongoDB 4.4+).
            response = await conn._hello(
                cluster_time,
                self._server_description.topology_version,
                self._settings.heartbeat_frequency,
            )
        else:
            # New connection handshake or polling hello (MongoDB <4.4).
            response = await conn._hello(cluster_time, None, None)
        return response, time.monotonic() - start
|
||||
|
||||
|
||||
class SrvMonitor(MonitorBase):
    def __init__(self, topology: Topology, topology_settings: TopologySettings):
        """Class to poll SRV records on a background thread.

        Pass a Topology and a TopologySettings.

        The Topology is weakly referenced.
        """
        super().__init__(
            topology,
            name="pymongo_srv_polling_thread",
            interval=common.MIN_SRV_RESCAN_INTERVAL,
            min_interval=topology_settings.heartbeat_frequency,
        )
        self._settings = topology_settings
        self._seedlist = self._settings._seeds
        assert isinstance(self._settings.fqdn, str)
        self._fqdn: str = self._settings.fqdn
        self._startup_time = time.monotonic()

    async def _run(self) -> None:
        """Poll once and push a non-empty seedlist to the Topology."""
        # Don't poll right after creation, wait 60 seconds first
        if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
            return

        polled = self._get_seedlist()
        if not polled:
            return

        self._seedlist = polled
        try:
            await self._topology.on_srv_update(self._seedlist)
        except ReferenceError:
            # Topology was garbage-collected.
            await self.close()

    def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
        """Poll SRV records for a seedlist.

        Returns the resolved seedlist, or None when resolution fails or
        yields no hosts. On success the executor interval is updated from
        the records' TTL.
        """
        try:
            resolver = _SrvResolver(
                self._fqdn,
                self._settings.pool_options.connect_timeout,
                self._settings.srv_service_name,
            )
            hosts, ttl = resolver.get_hosts_and_min_ttl()
            if len(hosts) == 0:
                # As per the spec: this should be treated as a failure.
                raise Exception
        except Exception:
            # As per the spec, upon encountering an error:
            # - An error must not be raised
            # - SRV records must be rescanned every heartbeatFrequencyMS
            # - Topology must be left unchanged
            self.request_check()
            return None

        # Reached only when resolution succeeded with at least one host.
        self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
        return hosts
|
||||
|
||||
|
||||
class _RttMonitor(MonitorBase):
    def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
        """Maintain round trip times for a server.

        The Topology is weakly referenced.
        """
        super().__init__(
            topology,
            name="pymongo_server_rtt_thread",
            interval=topology_settings.heartbeat_frequency,
            min_interval=common.MIN_HEARTBEAT_INTERVAL,
        )

        self._pool = pool
        self._moving_average = MovingAverage()
        self._moving_min = MovingMinimum()
        self._lock = _create_lock()

    async def close(self) -> None:
        """Stop the executor and reset the pool."""
        self.gc_safe_close()
        # Increment the generation and maybe close the socket. If the executor
        # thread has the socket checked out, it will be closed when checked in.
        await self._pool.reset()

    def add_sample(self, sample: float) -> None:
        """Add a RTT sample."""
        with self._lock:
            for tracker in (self._moving_average, self._moving_min):
                tracker.add_sample(sample)

    def get(self) -> tuple[Optional[float], float]:
        """Get the calculated average, or None if no samples yet and the min."""
        with self._lock:
            average = self._moving_average.get()
            minimum = self._moving_min.get()
            return average, minimum

    def reset(self) -> None:
        """Reset the average RTT."""
        with self._lock:
            for tracker in (self._moving_average, self._moving_min):
                tracker.reset()

    async def _run(self) -> None:
        """Take one RTT measurement and record it."""
        try:
            # NOTE: This thread is only run when using the streaming
            # heartbeat protocol (MongoDB 4.4+).
            # XXX: Skip check if the server is unknown?
            self.add_sample(await self._ping())
        except ReferenceError:
            # Topology was garbage-collected.
            await self.close()
        except Exception:
            await self._pool.reset()

    async def _ping(self) -> float:
        """Run a "hello" command and return the RTT."""
        async with self._pool.checkout() as conn:
            if self._executor._stopped:
                raise Exception("_RttMonitor closed")
            began = time.monotonic()
            await conn.hello()
            return time.monotonic() - began
|
||||
|
||||
|
||||
# Close monitors to cancel any in progress streaming checks before joining
|
||||
# executor threads. For an explanation of how this works see the comment
|
||||
# about _EXECUTORS in periodic_executor.py.
|
||||
_MONITORS = set()
|
||||
|
||||
|
||||
def _register(monitor: MonitorBase) -> None:
    """Add a weak reference to ``monitor`` to ``_MONITORS``.

    The reference's callback is ``_unregister``.
    """
    _MONITORS.add(weakref.ref(monitor, _unregister))
|
||||
|
||||
|
||||
def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None:
    """Remove ``monitor_ref`` from ``_MONITORS``.

    Installed by ``_register`` as the callback of each weak reference.
    """
    _MONITORS.remove(monitor_ref)
|
||||
|
||||
|
||||
def _shutdown_monitors() -> None:
    """Call ``gc_safe_close`` on every registered monitor still alive."""
    if _MONITORS is None:
        return

    # Iterate over a copy of the set. Closing monitors removes them.
    for ref in list(_MONITORS):
        monitor = ref()
        if monitor:
            monitor.gc_safe_close()

    # Drop the last strong reference held by the loop variable.
    monitor = None
|
||||
|
||||
|
||||
def _shutdown_resources() -> None:
    """Shut down monitors first, then executors.

    Each hook is looked up and truth-tested immediately before it is
    called, because either name may no longer be usable at shutdown.
    """
    # _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
    monitors_hook = _shutdown_monitors
    if monitors_hook:  # type:ignore[truthy-function]
        monitors_hook()
    executors_hook = _shutdown_executors
    if executors_hook:  # type:ignore[truthy-function]
        executors_hook()
|
||||
|
||||
|
||||
# Run _shutdown_resources() when the interpreter exits so monitors are closed
# before executors are shut down.
atexit.register(_shutdown_resources)
|
||||
1903
pymongo/asynchronous/monitoring.py
Normal file
1903
pymongo/asynchronous/monitoring.py
Normal file
File diff suppressed because it is too large
Load Diff
418
pymongo/asynchronous/network.py
Normal file
418
pymongo/asynchronous/network.py
Normal file
@ -0,0 +1,418 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Internal network layer helper methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous import helpers as _async_helpers
|
||||
from pymongo.asynchronous import message as _async_message
|
||||
from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.asynchronous.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.asynchronous.monitoring import _is_speculative_authenticate
|
||||
from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
_OperationCancelled,
|
||||
)
|
||||
from pymongo.network_layer import (
|
||||
_POLL_TIMEOUT,
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
_UNPACK_HEADER,
|
||||
BLOCKING_IO_ERRORS,
|
||||
async_sendall,
|
||||
)
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
from pymongo.asynchronous.client_session import ClientSession
|
||||
from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.monitoring import _EventListeners
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
# Marks this module as the asynchronous implementation.
# NOTE(review): presumably flipped to True in the generated synchronous copy
# of this file -- confirm against tools/synchro.sh.
_IS_SYNC = False
|
||||
|
||||
|
||||
async def command(
    conn: Connection,
    dbname: str,
    spec: MutableMapping[str, Any],
    is_mongos: bool,
    read_preference: Optional[_ServerMode],
    codec_options: CodecOptions[_DocumentType],
    session: Optional[ClientSession],
    client: Optional[AsyncMongoClient],
    check: bool = True,
    allowable_errors: Optional[Sequence[Union[str, int]]] = None,
    address: Optional[_Address] = None,
    listeners: Optional[_EventListeners] = None,
    max_bson_size: Optional[int] = None,
    read_concern: Optional[ReadConcern] = None,
    parse_write_concern_error: bool = False,
    collation: Optional[_CollationIn] = None,
    compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
    use_op_msg: bool = False,
    unacknowledged: bool = False,
    user_fields: Optional[Mapping[str, Any]] = None,
    exhaust_allowed: bool = False,
    write_concern: Optional[WriteConcern] = None,
) -> _DocumentType:
    """Execute a command over the socket, or raise socket.error.

    The command is encoded (OP_MSG or legacy query), sent on ``conn``, and
    the first reply document is returned. Command started/succeeded/failed
    events are logged (when ``client`` is given and debug logging is on) and
    published to ``listeners`` (when enabled for commands).

    Note that ``spec`` is mutated in place: ``readConcern`` and ``collation``
    keys may be added to it.

    :param conn: a Connection instance
    :param dbname: name of the database on which to run the command
    :param spec: a command document as an ordered dict type, eg SON.
    :param is_mongos: are we connected to a mongos?
    :param read_preference: a read preference
    :param codec_options: a CodecOptions instance
    :param session: optional ClientSession instance.
    :param client: optional AsyncMongoClient instance for updating $clusterTime.
    :param check: raise OperationFailure if there are errors
    :param allowable_errors: errors to ignore if `check` is True
    :param address: the (host, port) of `conn`
    :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
    :param max_bson_size: The maximum encoded bson size for this server
    :param read_concern: The read concern for this command.
    :param parse_write_concern_error: Whether to parse the ``writeConcernError``
        field in the command response.
    :param collation: The collation for this command.
    :param compression_ctx: optional compression Context.
    :param use_op_msg: True if we should use OP_MSG.
    :param unacknowledged: True if this is an unacknowledged command.
    :param user_fields: Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
    :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
    :param write_concern: optional WriteConcern, forwarded together with
        ``spec`` to ``_csot.apply_write_concern``.
    :return: the first document of the reply, or ``{"ok": 1}`` for an
        unacknowledged OP_MSG command (no reply is read in that case).
    """
    # The command name is the first key of the command document.
    name = next(iter(spec))
    ns = dbname + ".$cmd"
    speculative_hello = False

    # Publish the original command document, perhaps with lsid and $clusterTime.
    orig = spec
    if is_mongos and not use_op_msg:
        assert read_preference is not None
        spec = _async_message._maybe_add_read_preference(spec, read_preference)
    if read_concern and not (session and session.in_transaction):
        if read_concern.level:
            spec["readConcern"] = read_concern.document
        if session:
            session._update_read_concern(spec, conn)
    if collation is not None:
        spec["collation"] = collation

    publish = listeners is not None and listeners.enabled_for_commands
    # Wall-clock start; durations below are datetime.timedelta values.
    start = datetime.datetime.now()
    if publish:
        speculative_hello = _is_speculative_authenticate(name, spec)

    # Commands listed in _NO_COMPRESSION are always sent uncompressed.
    if compression_ctx and name.lower() in _NO_COMPRESSION:
        compression_ctx = None

    if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
        # Both names are rebound, so the encrypted command is what gets
        # sent and what gets published to listeners.
        spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)

    # Support CSOT
    if client:
        conn.apply_timeout(client, spec)
    _csot.apply_write_concern(spec, write_concern)

    if use_op_msg:
        flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
        flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
        request_id, msg, size, max_doc_size = _async_message._op_msg(
            flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
        )
        # If this is an unacknowledged write then make sure the encoded doc(s)
        # are small enough, otherwise rely on the server to return an error.
        if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
            _async_message._raise_document_too_large(name, size, max_bson_size)
    else:
        request_id, msg, size = _async_message._query(
            0, ns, 0, -1, spec, None, codec_options, compression_ctx
        )

    if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD:
        _async_message._raise_document_too_large(
            name, size, max_bson_size + _async_message._COMMAND_OVERHEAD
        )
    # Command-started log entry (only when a client is available).
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=spec,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_start(
            orig,
            dbname,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
        )

    try:
        await async_sendall(conn.conn, msg)
        if use_op_msg and unacknowledged:
            # Unacknowledged, fake a successful command response.
            reply = None
            response_doc: _DocumentOut = {"ok": 1}
        else:
            reply = await receive_message(conn, request_id)
            conn.more_to_come = reply.more_to_come
            unpacked_docs = reply.unpack_response(
                codec_options=codec_options, user_fields=user_fields
            )

            response_doc = unpacked_docs[0]
            if client:
                await client._process_response(response_doc, session)
            if check:
                _async_helpers._check_command_response(
                    response_doc,
                    conn.max_wire_version,
                    allowable_errors,
                    parse_write_concern_error=parse_write_concern_error,
                )
    except Exception as exc:
        # Log and publish the failure, then re-raise the original exception.
        duration = datetime.datetime.now() - start
        if isinstance(exc, (NotPrimaryError, OperationFailure)):
            failure: _DocumentOut = exc.details  # type: ignore[assignment]
        else:
            failure = _async_message._convert_exception(exc)
        if client is not None:
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                # NOTE(review): `duration` is a datetime.timedelta here even
                # though the field is named durationMS -- presumably converted
                # by the logger; confirm.
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(spec)),
                    databaseName=dbname,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=conn.address[0],
                    serverPort=conn.address[1],
                    serviceId=conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
        if publish:
            assert listeners is not None
            assert address is not None
            listeners.publish_command_failure(
                duration,
                failure,
                name,
                request_id,
                address,
                conn.server_connection_id,
                service_id=conn.service_id,
                database_name=dbname,
            )
        raise
    duration = datetime.datetime.now() - start
    # Command-succeeded log entry and event.
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.SUCCEEDED,
                durationMS=duration,
                reply=response_doc,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
                speculative_authenticate="speculativeAuthenticate" in orig,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_success(
            duration,
            response_doc,
            name,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
            speculative_hello=speculative_hello,
            database_name=dbname,
        )

    # The reply is decrypted only after logging/publishing, so the events
    # above carry the reply as it was unpacked from the wire.
    if client and client._encrypter and reply:
        # NOTE(review): encrypt() is awaited above but decrypt() is not --
        # confirm that decrypt() is synchronous in the async encrypter.
        decrypted = client._encrypter.decrypt(reply.raw_command_response())
        response_doc = cast(
            "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
        )

    return response_doc  # type: ignore[return-value]
|
||||
|
||||
|
||||
async def receive_message(
    conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error.

    Reads the 16-byte message header, validates it, reads the remaining
    payload (decompressing it when the opcode is 2012) and unpacks it with
    the handler registered for the opcode in ``_UNPACK_REPLY``.

    :param conn: connection to read from.
    :param request_id: id the reply must answer, or ``None`` to skip the check.
    :param max_message_size: largest acceptable total message length.
    :raises ProtocolError: on a mismatched response id, an invalid length or
        an unknown opcode.
    """
    # The CSOT deadline wins; otherwise derive one from the socket timeout.
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        sock_timeout = conn.conn.gettimeout()
        deadline = time.monotonic() + sock_timeout if sock_timeout else None

    header = await _receive_data_on_socket(conn, 16, deadline)
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(header)

    # No request_id for exhaust cursor "getMore".
    if request_id is not None and request_id != response_to:
        raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
    if length <= 16:
        raise ProtocolError(
            f"Message length ({length!r}) not longer than standard message header size (16)"
        )
    if length > max_message_size:
        raise ProtocolError(
            f"Message length ({length!r}) is larger than server max "
            f"message size ({max_message_size!r})"
        )

    if op_code == 2012:
        # Compressed message: a 9-byte compression header follows the
        # 16-byte message header, hence the remaining length - 25 bytes.
        compression_header = await _receive_data_on_socket(conn, 9, deadline)
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(compression_header)
        compressed = await _receive_data_on_socket(conn, length - 25, deadline)
        data = decompress(compressed, compressor_id)
    else:
        data = await _receive_data_on_socket(conn, length - 16, deadline)

    try:
        unpack_reply = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError(
            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
        ) from None
    else:
        return unpack_reply(data)
|
||||
|
||||
|
||||
async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Block until at least one byte is read, or a timeout, or a cancel.

    :param conn: connection whose ``conn`` socket is polled.
    :param deadline: ``time.monotonic()`` value after which to give up; a
        falsy value means no deadline.
    :raises _OperationCancelled: if ``conn.cancel_context`` is cancelled.
    :raises socket.timeout: if the deadline has passed and the socket is
        still not readable.
    """
    sock = conn.conn
    # Check if the connection's socket has been manually closed
    if sock.fileno() == -1:
        return

    expired = False
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            readable = True
        else:
            # Wait at most _POLL_TIMEOUT for the socket to become readable
            # and then check for cancellation.
            poll_for = _POLL_TIMEOUT
            if deadline:
                remaining = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                expired = expired or remaining <= 0
                poll_for = max(min(remaining, _POLL_TIMEOUT), 0)
            readable = conn.socket_checker.select(sock, read=True, timeout=poll_for)

        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if readable:
            return
        if expired:
            raise socket.timeout("timed out")
        # Yield to the event loop before polling again.
        await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def _receive_data_on_socket(
    conn: Connection, length: int, deadline: Optional[float]
) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` and return them as a view.

    :param conn: connection to read from.
    :param length: number of bytes to read.
    :param deadline: ``time.monotonic()`` deadline passed on to
        ``wait_for_read``, or ``None``.
    :raises socket.timeout: when a blocking I/O error is raised while reading.
    :raises OSError: when the peer closes the connection (a zero-byte read),
        or on any socket error other than ``EINTR``.
    """
    view = memoryview(bytearray(length))
    received = 0
    while received < length:
        try:
            await wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                time_left = max(deadline - time.monotonic(), 0)
                conn.set_conn_timeout(time_left)
            chunk_length = conn.conn.recv_into(view[received:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as exc:
            # An interrupted call is retried; anything else propagates.
            if _errno_from_exception(exc) != errno.EINTR:
                raise
            continue
        if chunk_length == 0:
            raise OSError("connection closed")

        received += chunk_length

    return view
|
||||
625
pymongo/asynchronous/operations.py
Normal file
625
pymongo/asynchronous/operations.py
Normal file
@ -0,0 +1,625 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Operation class definitions."""
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo.asynchronous import helpers
|
||||
from pymongo.asynchronous.collation import validate_collation_or_none
|
||||
from pymongo.asynchronous.common import validate_is_mapping, validate_list
|
||||
from pymongo.asynchronous.helpers import _gen_index_name, _index_document, _index_list
|
||||
from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.bulk import _Bulk
|
||||
|
||||
# Marks this module as the asynchronous implementation.
# NOTE(review): presumably flipped to True in the generated synchronous copy
# of this file -- confirm against tools/synchro.sh.
_IS_SYNC = False
|
||||
|
||||
# Hint supports an index name ("myIndex"), a list of either strings or index
# pairs ([('x', 1), ('y', -1), 'z']), or a dictionary.
_IndexList = Union[
    Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
]
# An index name, or any of the _IndexList forms.
_IndexKeyHint = Union[str, _IndexList]
|
||||
|
||||
|
||||
class _Op(str, enum.Enum):
    """Names of driver operations.

    Members mix in ``str``, so each one compares equal to its string value.
    Most values are MongoDB command names; ``TEST`` ("testOperation") is a
    placeholder. NOTE(review): presumably used to label operations for
    logging/selection -- confirm against callers.
    """

    ABORT = "abortTransaction"
    AGGREGATE = "aggregate"
    COMMIT = "commitTransaction"
    COUNT = "count"
    CREATE = "create"
    CREATE_INDEXES = "createIndexes"
    CREATE_SEARCH_INDEXES = "createSearchIndexes"
    DELETE = "delete"
    DISTINCT = "distinct"
    DROP = "drop"
    DROP_DATABASE = "dropDatabase"
    DROP_INDEXES = "dropIndexes"
    DROP_SEARCH_INDEXES = "dropSearchIndexes"
    END_SESSIONS = "endSessions"
    FIND_AND_MODIFY = "findAndModify"
    FIND = "find"
    INSERT = "insert"
    LIST_COLLECTIONS = "listCollections"
    LIST_INDEXES = "listIndexes"
    LIST_SEARCH_INDEX = "listSearchIndexes"
    LIST_DATABASES = "listDatabases"
    UPDATE = "update"
    UPDATE_INDEX = "updateIndex"
    UPDATE_SEARCH_INDEX = "updateSearchIndex"
    RENAME = "rename"
    GETMORE = "getMore"
    KILL_CURSORS = "killCursors"
    TEST = "testOperation"
|
||||
|
||||
|
||||
class InsertOne(Generic[_DocumentType]):
    """Represents an insert_one operation."""

    __slots__ = ("_doc",)

    def __init__(self, document: _DocumentType) -> None:
        """Create an InsertOne instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.

        :param document: The document to insert. If the document is missing an
            _id field one will be added.
        """
        self._doc = document

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Queue this insert on the _Bulk instance `bulkobj`."""
        bulkobj.add_insert(self._doc)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        return "InsertOne({!r})".format(self._doc)

    def __eq__(self, other: Any) -> bool:
        # Only instances of exactly the same class are comparable.
        if type(other) != type(self):
            return NotImplemented
        return other._doc == self._doc

    def __ne__(self, other: Any) -> bool:
        equal = self == other
        return not equal
|
||||
|
||||
|
||||
class DeleteOne:
    """Represents a delete_one operation."""

    __slots__ = ("_filter", "_collation", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        collation: Optional[_CollationIn] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create a DeleteOne instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.

        :param filter: A query that matches the document to delete.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.4 and above.

        .. versionchanged:: 3.11
           Added the ``hint`` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        # String hints (index names) and None are stored untouched; any other
        # form is normalised into an index document.
        normalized: Union[str, dict[str, Any], None]
        if hint is None or isinstance(hint, str):
            normalized = hint
        else:
            normalized = helpers._index_document(hint)
        self._hint = normalized
        self._filter = filter
        self._collation = collation

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Queue this delete (limit 1) on the _Bulk instance `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        bulkobj.add_delete(self._filter, 1, collation=checked_collation, hint=self._hint)

    def __repr__(self) -> str:
        return "DeleteOne({!r}, {!r}, {!r})".format(self._filter, self._collation, self._hint)

    def __eq__(self, other: Any) -> bool:
        # Only instances of exactly the same class are comparable.
        if type(other) != type(self):
            return NotImplemented
        theirs = (other._filter, other._collation, other._hint)
        mine = (self._filter, self._collation, self._hint)
        return theirs == mine

    def __ne__(self, other: Any) -> bool:
        equal = self == other
        return not equal
|
||||
|
||||
|
||||
class DeleteMany:
    """Represents a delete_many operation."""

    __slots__ = ("_filter", "_collation", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        collation: Optional[_CollationIn] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create a DeleteMany instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.

        :param filter: A query that matches the documents to delete.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.4 and above.

        .. versionchanged:: 3.11
           Added the ``hint`` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        # String hints (index names) and None are stored untouched; any other
        # form is normalised into an index document.
        normalized: Union[str, dict[str, Any], None]
        if hint is None or isinstance(hint, str):
            normalized = hint
        else:
            normalized = helpers._index_document(hint)
        self._hint = normalized
        self._filter = filter
        self._collation = collation

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Queue this delete (limit 0, i.e. all matches) on `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        bulkobj.add_delete(self._filter, 0, collation=checked_collation, hint=self._hint)

    def __repr__(self) -> str:
        return "DeleteMany({!r}, {!r}, {!r})".format(self._filter, self._collation, self._hint)

    def __eq__(self, other: Any) -> bool:
        # Only instances of exactly the same class are comparable.
        if type(other) != type(self):
            return NotImplemented
        theirs = (other._filter, other._collation, other._hint)
        mine = (self._filter, self._collation, self._hint)
        return theirs == mine

    def __ne__(self, other: Any) -> bool:
        equal = self == other
        return not equal
|
||||
|
||||
|
||||
class ReplaceOne(Generic[_DocumentType]):
    """Represents a replace_one operation."""

    __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        replacement: Union[_DocumentType, RawBSONDocument],
        upsert: bool = False,
        collation: Optional[_CollationIn] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create a ReplaceOne instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.

        :param filter: A query that matches the document to replace.
        :param replacement: The new document.
        :param upsert: If ``True``, perform an insert if no documents
            match the filter.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.2 and above.

        .. versionchanged:: 3.11
           Added the ``hint`` option.
        .. versionchanged:: 3.5
           Added the ``collation`` option.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        if upsert is not None:
            validate_boolean("upsert", upsert)
        # String hints (index names) and None are stored untouched; any other
        # form is normalised into an index document.
        if hint is not None and not isinstance(hint, str):
            self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
        else:
            self._hint = hint
        self._filter = filter
        self._doc = replacement
        self._upsert = upsert
        self._collation = collation

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        bulkobj.add_replace(
            self._filter,
            self._doc,
            self._upsert,
            collation=validate_collation_or_none(self._collation),
            hint=self._hint,
        )

    def __eq__(self, other: Any) -> bool:
        """Compare every field, including the hint, with ``other``.

        Returns ``NotImplemented`` unless ``other`` is exactly the same class.
        """
        if type(other) == type(self):
            return (
                other._filter,
                other._doc,
                other._upsert,
                other._collation,
                other._hint,
            ) == (
                self._filter,
                self._doc,
                self._upsert,
                self._collation,
                # Compare against our own hint. The previous code used
                # other._hint on both sides, so operations differing only in
                # their hint compared equal.
                self._hint,
            )
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        return not self == other

    def __repr__(self) -> str:
        return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
            self.__class__.__name__,
            self._filter,
            self._doc,
            self._upsert,
            self._collation,
            self._hint,
        )
|
||||
|
||||
|
||||
class _UpdateOp:
|
||||
"""Private base class for update operations."""
|
||||
|
||||
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter: Mapping[str, Any],
|
||||
doc: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool,
|
||||
collation: Optional[_CollationIn],
|
||||
array_filters: Optional[list[Mapping[str, Any]]],
|
||||
hint: Optional[_IndexKeyHint],
|
||||
):
|
||||
if filter is not None:
|
||||
validate_is_mapping("filter", filter)
|
||||
if upsert is not None:
|
||||
validate_boolean("upsert", upsert)
|
||||
if array_filters is not None:
|
||||
validate_list("array_filters", array_filters)
|
||||
if hint is not None and not isinstance(hint, str):
|
||||
self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
|
||||
else:
|
||||
self._hint = hint
|
||||
|
||||
self._filter = filter
|
||||
self._doc = doc
|
||||
self._upsert = upsert
|
||||
self._collation = collation
|
||||
self._array_filters = array_filters
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, type(self)):
|
||||
return (
|
||||
other._filter,
|
||||
other._doc,
|
||||
other._upsert,
|
||||
other._collation,
|
||||
other._array_filters,
|
||||
other._hint,
|
||||
) == (
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._array_filters,
|
||||
self._hint,
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
|
||||
self.__class__.__name__,
|
||||
self._filter,
|
||||
self._doc,
|
||||
self._upsert,
|
||||
self._collation,
|
||||
self._array_filters,
|
||||
self._hint,
|
||||
)
|
||||
|
||||
|
||||
class UpdateOne(_UpdateOp):
    """Represents an update_one operation."""

    __slots__ = ()

    def __init__(
        self,
        filter: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        upsert: bool = False,
        collation: Optional[_CollationIn] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create an UpdateOne instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`.

        :param filter: A query that matches the document to update.
        :param update: The modifications to apply.
        :param upsert: If ``True``, perform an insert if no documents
            match the filter.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param array_filters: A list of filters specifying which
            array elements an update should apply.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.2 and above.

        .. versionchanged:: 3.11
           Added the `hint` option.
        .. versionchanged:: 3.9
           Added the ability to accept a pipeline as the `update`.
        .. versionchanged:: 3.6
           Added the `array_filters` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        super().__init__(filter, update, upsert, collation, array_filters, hint)

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        # The third positional argument is passed as False here.
        # NOTE(review): presumably the "multi" flag of add_update -- confirm.
        bulkobj.add_update(
            self._filter,
            self._doc,
            False,
            self._upsert,
            collation=validate_collation_or_none(self._collation),
            array_filters=self._array_filters,
            hint=self._hint,
        )
|
||||
|
||||
|
||||
class UpdateMany(_UpdateOp):
    """Represents an update_many operation."""

    __slots__ = ()

    def __init__(
        self,
        filter: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        upsert: bool = False,
        collation: Optional[_CollationIn] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Build an UpdateMany instance.

        Intended to be passed to
        :meth:`~pymongo.collection.AsyncCollection.bulk_write`.

        :param filter: A query that matches the documents to update.
        :param update: The modifications to apply.
        :param upsert: If ``True``, perform an insert if no documents
            match the filter.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param array_filters: A list of filters specifying which
            array elements an update should apply.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.2 and above.

        .. versionchanged:: 3.11
           Added the `hint` option.
        .. versionchanged:: 3.9
           Added the ability to accept a pipeline as the `update`.
        .. versionchanged:: 3.6
           Added the `array_filters` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        # The base initializer receives every argument unchanged, in order.
        super().__init__(filter, update, upsert, collation, array_filters, hint)

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Queue this update on the _Bulk instance `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        bulkobj.add_update(
            self._filter,
            self._doc,
            # Third positional flag: True here, False in UpdateOne
            # (presumably the "multi" switch -- confirm against _Bulk.add_update).
            True,
            self._upsert,
            collation=checked_collation,
            array_filters=self._array_filters,
            hint=self._hint,
        )
|
||||
|
||||
|
||||
class IndexModel:
    """Represents an index to create."""

    __slots__ = ("__document",)

    def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
        """Build an Index instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.create_indexes`.

        `keys` is either a single key or a list containing (key, direction)
        pairs or keys. When a direction is omitted,
        :data:`~pymongo.ASCENDING` is assumed.
        The key(s) must be an instance of :class:`str`, and the direction(s) must
        be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
        :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
        :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).

        Valid options include, but are not limited to:

          - `name`: custom name to use for this index - if none is
            given, a name will be generated.
          - `unique`: if ``True``, creates a uniqueness constraint on the index.
          - `background`: if ``True``, this index should be created in the
            background.
          - `sparse`: if ``True``, omit from the index any documents that lack
            the indexed field.
          - `bucketSize`: for use with geoHaystack indexes.
            Number of documents to group together within a certain proximity
            to a given longitude and latitude.
          - `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
            index.
          - `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
            index.
          - `expireAfterSeconds`: <int> Used to create an expiring (TTL)
            collection. MongoDB will automatically delete documents from
            this collection after <int> seconds. The indexed field must
            be a UTC datetime or the data will not expire.
          - `partialFilterExpression`: A document that specifies a filter for
            a partial index.
          - `collation`: An instance of :class:`~pymongo.collation.Collation`
            that specifies the collation to use.
          - `wildcardProjection`: Allows users to include or exclude specific
            field paths from a `wildcard index`_ using the { "$**" : 1} key
            pattern. Requires MongoDB >= 4.2.
          - `hidden`: if ``True``, this index will be hidden from the query
            planner and will not be evaluated as part of query plan
            selection. Requires MongoDB >= 4.4.

        See the MongoDB documentation for a full list of supported options by
        server version.

        :param keys: a single key or a list containing (key, direction) pairs
            or keys specifying the index to create.
        :param kwargs: any additional index creation
            options (see the above list) should be passed as keyword
            arguments.

        .. versionchanged:: 3.11
           Added the ``hidden`` option.
        .. versionchanged:: 3.2
           Added the ``partialFilterExpression`` option to support partial
           indexes.

        .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
        """
        key_list = _index_list(keys)
        # A missing or explicit ``None`` name is replaced by a generated one.
        if kwargs.get("name") is None:
            kwargs["name"] = _gen_index_name(key_list)
        kwargs["key"] = _index_document(key_list)

        # The collation is validated separately and only stored when present.
        checked_collation = validate_collation_or_none(kwargs.pop("collation", None))
        if checked_collation is not None:
            kwargs["collation"] = checked_collation

        # The keyword dict itself becomes the index document.
        self.__document = kwargs

    @property
    def document(self) -> dict[str, Any]:
        """An index document suitable for passing to the createIndexes
        command.
        """
        return self.__document
|
||||
|
||||
|
||||
class SearchIndexModel:
    """Represents a search index to create."""

    __slots__ = ("__document",)

    def __init__(
        self,
        definition: Mapping[str, Any],
        name: Optional[str] = None,
        type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Build a Search Index instance.

        For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`.

        :param definition: The definition for this index.
        :param name: The name for this index, if present.
        :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
        :param kwargs: Keyword arguments supplying any additional options.

        .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
        .. versionadded:: 4.5
        .. versionchanged:: 4.7
           Added the type and kwargs arguments.
        """
        spec: dict[str, Any] = {"definition": definition}
        # Keep "name" ahead of "definition" in key order when it is supplied.
        if name is not None:
            spec = {"name": name, **spec}
        if type is not None:
            spec["type"] = type
        # Extra options are appended after the named fields.
        spec.update(kwargs)
        self.__document: dict[str, Any] = spec

    @property
    def document(self) -> Mapping[str, Any]:
        """The document for this index."""
        return self.__document
|
||||
209
pymongo/asynchronous/periodic_executor.py
Normal file
209
pymongo/asynchronous/periodic_executor.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Run a target function on a background thread."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class PeriodicExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
interval: float,
|
||||
min_interval: float,
|
||||
target: Any,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Run a target function periodically on a background thread.
|
||||
|
||||
If the target's return value is false, the executor stops.
|
||||
|
||||
:param interval: Seconds between calls to `target`.
|
||||
:param min_interval: Minimum seconds between calls if `wake` is
|
||||
called very often.
|
||||
:param target: A function.
|
||||
:param name: A name to give the underlying thread.
|
||||
"""
|
||||
# threading.Event and its internal condition variable are expensive
|
||||
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
|
||||
# The executor's design is constrained by several Python issues, see
|
||||
# "periodic_executor.rst" in this repository.
|
||||
self._event = False
|
||||
self._interval = interval
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
self._thread_will_exit = False
|
||||
self._lock = _ALock(_create_lock())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def _run_async(self) -> None:
|
||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
Not safe to call from multiple threads at once.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._thread_will_exit:
|
||||
# If the background thread has read self._stopped as True
|
||||
# there is a chance that it has not yet exited. The call to
|
||||
# join should not block indefinitely because there is no
|
||||
# other work done outside the while loop in self._run.
|
||||
try:
|
||||
assert self._thread is not None
|
||||
self._thread.join()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
self._thread_will_exit = False
|
||||
self._stopped = False
|
||||
started: Any = False
|
||||
try:
|
||||
started = self._thread and self._thread.is_alive()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
|
||||
if not started:
|
||||
if _IS_SYNC:
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
else:
|
||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
||||
thread.daemon = True
|
||||
self._thread = weakref.proxy(thread)
|
||||
_register_executor(self)
|
||||
# Mitigation to RuntimeError firing when thread starts on shutdown
|
||||
# https://github.com/python/cpython/issues/114570
|
||||
try:
|
||||
thread.start()
|
||||
except RuntimeError as e:
|
||||
if "interpreter shutdown" in str(e) or sys.is_finalizing():
|
||||
self._thread = None
|
||||
return
|
||||
raise
|
||||
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._thread is not None:
|
||||
try:
|
||||
self._thread.join(timeout)
|
||||
except (ReferenceError, RuntimeError):
|
||||
# Thread already terminated, or not yet started.
|
||||
pass
|
||||
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _should_stop(self) -> bool:
|
||||
async with self._lock:
|
||||
if self._stopped:
|
||||
self._thread_will_exit = True
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _run(self) -> None:
|
||||
while not await self._should_stop():
|
||||
try:
|
||||
if not await self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
except BaseException:
|
||||
async with self._lock:
|
||||
self._stopped = True
|
||||
self._thread_will_exit = True
|
||||
|
||||
raise
|
||||
|
||||
if self._skip_sleep:
|
||||
self._skip_sleep = False
|
||||
else:
|
||||
deadline = time.monotonic() + self._interval
|
||||
while not self._stopped and time.monotonic() < deadline:
|
||||
await asyncio.sleep(self._min_interval)
|
||||
if self._event:
|
||||
break # Early wake.
|
||||
|
||||
self._event = False
|
||||
|
||||
|
||||
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()


def _register_executor(executor: PeriodicExecutor) -> None:
    """Add a weak reference to `executor` to the module-level registry.

    The weak reference carries `_on_executor_deleted` as its callback.
    """
    _EXECUTORS.add(weakref.ref(executor, _on_executor_deleted))
|
||||
|
||||
|
||||
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
    """Weakref callback: remove `ref` from the executor registry."""
    _EXECUTORS.remove(ref)
|
||||
|
||||
|
||||
def _shutdown_executors() -> None:
    """Close every registered executor, then wait up to a second on each."""
    if _EXECUTORS is None:
        return

    # Copy the set. Stopping threads has the side effect of removing executors.
    snapshot = list(_EXECUTORS)

    # First signal all executors to close...
    for weak in snapshot:
        live = weak()
        if live:
            live.close()

    # ...then try to join them.
    for weak in snapshot:
        live = weak()
        if live:
            live.join(1)

    # Release the last strong reference taken by the loops above.
    live = None
|
||||
2128
pymongo/asynchronous/pool.py
Normal file
2128
pymongo/asynchronous/pool.py
Normal file
File diff suppressed because it is too large
Load Diff
624
pymongo/asynchronous/read_preferences.py
Normal file
624
pymongo/asynchronous/read_preferences.py
Normal file
@ -0,0 +1,624 @@
|
||||
# Copyright 2012-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for choosing which member of a replica set to read from."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import abc
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
|
||||
|
||||
from pymongo.asynchronous import max_staleness_selectors
|
||||
from pymongo.asynchronous.server_selectors import (
|
||||
member_with_tags_server_selector,
|
||||
secondary_with_tags_server_selector,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
from pymongo.asynchronous.topology_description import TopologyDescription
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_PRIMARY = 0
|
||||
_PRIMARY_PREFERRED = 1
|
||||
_SECONDARY = 2
|
||||
_SECONDARY_PREFERRED = 3
|
||||
_NEAREST = 4
|
||||
|
||||
|
||||
_MONGOS_MODES = (
|
||||
"primary",
|
||||
"primaryPreferred",
|
||||
"secondary",
|
||||
"secondaryPreferred",
|
||||
"nearest",
|
||||
)
|
||||
|
||||
_Hedge = Mapping[str, Any]
|
||||
_TagSets = Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
|
||||
"""Validate tag sets for a MongoClient."""
|
||||
if tag_sets is None:
|
||||
return tag_sets
|
||||
|
||||
if not isinstance(tag_sets, (list, tuple)):
|
||||
raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
|
||||
if len(tag_sets) == 0:
|
||||
raise ValueError(
|
||||
f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
|
||||
)
|
||||
|
||||
for tags in tag_sets:
|
||||
if not isinstance(tags, abc.Mapping):
|
||||
raise TypeError(
|
||||
f"Tag set {tags!r} invalid, must be an instance of dict, "
|
||||
"bson.son.SON or other type that inherits from "
|
||||
"collection.Mapping"
|
||||
)
|
||||
|
||||
return list(tag_sets)
|
||||
|
||||
|
||||
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
|
||||
return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
|
||||
|
||||
|
||||
# Some duplication with common.py to avoid import cycle.
|
||||
def _validate_max_staleness(max_staleness: Any) -> int:
|
||||
"""Validate max_staleness."""
|
||||
if max_staleness == -1:
|
||||
return -1
|
||||
|
||||
if not isinstance(max_staleness, int):
|
||||
raise TypeError(_invalid_max_staleness_msg(max_staleness))
|
||||
|
||||
if max_staleness <= 0:
|
||||
raise ValueError(_invalid_max_staleness_msg(max_staleness))
|
||||
|
||||
return max_staleness
|
||||
|
||||
|
||||
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
|
||||
"""Validate hedge."""
|
||||
if hedge is None:
|
||||
return None
|
||||
|
||||
if not isinstance(hedge, dict):
|
||||
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
|
||||
|
||||
return hedge
|
||||
|
||||
|
||||
class _ServerMode:
|
||||
"""Base class for all read preferences."""
|
||||
|
||||
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: int,
|
||||
tag_sets: Optional[_TagSets] = None,
|
||||
max_staleness: int = -1,
|
||||
hedge: Optional[_Hedge] = None,
|
||||
) -> None:
|
||||
self.__mongos_mode = _MONGOS_MODES[mode]
|
||||
self.__mode = mode
|
||||
self.__tag_sets = _validate_tag_sets(tag_sets)
|
||||
self.__max_staleness = _validate_max_staleness(max_staleness)
|
||||
self.__hedge = _validate_hedge(hedge)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of this read preference."""
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def mongos_mode(self) -> str:
|
||||
"""The mongos mode of this read preference."""
|
||||
return self.__mongos_mode
|
||||
|
||||
@property
|
||||
def document(self) -> dict[str, Any]:
|
||||
"""Read preference as a document."""
|
||||
doc: dict[str, Any] = {"mode": self.__mongos_mode}
|
||||
if self.__tag_sets not in (None, [{}]):
|
||||
doc["tags"] = self.__tag_sets
|
||||
if self.__max_staleness != -1:
|
||||
doc["maxStalenessSeconds"] = self.__max_staleness
|
||||
if self.__hedge not in (None, {}):
|
||||
doc["hedge"] = self.__hedge
|
||||
return doc
|
||||
|
||||
@property
|
||||
def mode(self) -> int:
|
||||
"""The mode of this read preference instance."""
|
||||
return self.__mode
|
||||
|
||||
@property
|
||||
def tag_sets(self) -> _TagSets:
|
||||
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
|
||||
read only from members whose ``dc`` tag has the value ``"ny"``.
|
||||
To specify a priority-order for tag sets, provide a list of
|
||||
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
|
||||
set, ``{}``, means "read from any member that matches the mode,
|
||||
ignoring tags." MongoClient tries each set of tags in turn
|
||||
until it finds a set of tags with at least one matching member.
|
||||
For example, to only send a query to an analytic node::
|
||||
|
||||
Nearest(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
Or using :class:`SecondaryPreferred`::
|
||||
|
||||
SecondaryPreferred(tag_sets=[{"node":"analytics"}])
|
||||
|
||||
.. seealso:: `Data-Center Awareness
|
||||
<https://www.mongodb.com/docs/manual/data-center-awareness/>`_
|
||||
"""
|
||||
return list(self.__tag_sets) if self.__tag_sets else [{}]
|
||||
|
||||
@property
|
||||
def max_staleness(self) -> int:
|
||||
"""The maximum estimated length of time (in seconds) a replica set
|
||||
secondary can fall behind the primary in replication before it will
|
||||
no longer be selected for operations, or -1 for no maximum.
|
||||
"""
|
||||
return self.__max_staleness
|
||||
|
||||
@property
|
||||
def hedge(self) -> Optional[_Hedge]:
|
||||
"""The read preference ``hedge`` parameter.
|
||||
|
||||
A dictionary that configures how the server will perform hedged reads.
|
||||
It consists of the following keys:
|
||||
|
||||
- ``enabled``: Enables or disables hedged reads in sharded clusters.
|
||||
|
||||
Hedged reads are automatically enabled in MongoDB 4.4+ when using a
|
||||
``nearest`` read preference. To explicitly enable hedged reads, set
|
||||
the ``enabled`` key to ``true``::
|
||||
|
||||
>>> Nearest(hedge={'enabled': True})
|
||||
|
||||
To explicitly disable hedged reads, set the ``enabled`` key to
|
||||
``False``::
|
||||
|
||||
>>> Nearest(hedge={'enabled': False})
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
return self.__hedge
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
"""The wire protocol version the server must support.
|
||||
|
||||
Some read preferences impose version requirements on all servers (e.g.
|
||||
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
|
||||
|
||||
All servers' maxWireVersion must be at least this read preference's
|
||||
`min_wire_version`, or the driver raises
|
||||
:exc:`~pymongo.errors.ConfigurationError`.
|
||||
"""
|
||||
return 0 if self.__max_staleness == -1 else 5
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
|
||||
self.name,
|
||||
self.__tag_sets,
|
||||
self.__max_staleness,
|
||||
self.__hedge,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return (
|
||||
self.mode == other.mode
|
||||
and self.tag_sets == other.tag_sets
|
||||
and self.max_staleness == other.max_staleness
|
||||
and self.hedge == other.hedge
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
"""Return value of object for pickling.
|
||||
|
||||
Needed explicitly because __slots__() defined.
|
||||
"""
|
||||
return {
|
||||
"mode": self.__mode,
|
||||
"tag_sets": self.__tag_sets,
|
||||
"max_staleness": self.__max_staleness,
|
||||
"hedge": self.__hedge,
|
||||
}
|
||||
|
||||
def __setstate__(self, value: Mapping[str, Any]) -> None:
|
||||
"""Restore from pickling."""
|
||||
self.__mode = value["mode"]
|
||||
self.__mongos_mode = _MONGOS_MODES[self.__mode]
|
||||
self.__tag_sets = _validate_tag_sets(value["tag_sets"])
|
||||
self.__max_staleness = _validate_max_staleness(value["max_staleness"])
|
||||
self.__hedge = _validate_hedge(value["hedge"])
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
return selection
|
||||
|
||||
|
||||
class Primary(_ServerMode):
    """Primary read preference.

    * When directly connected to one mongod queries are allowed if the server
      is standalone or a replica set primary.
    * When connected to a mongos queries are sent to the primary of a shard.
    * When connected to a replica set queries are sent to the primary of
      the replica set.
    """

    __slots__ = ()

    def __init__(self) -> None:
        # Primary takes no tag sets, staleness bound or hedge options.
        super().__init__(_PRIMARY)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to a Selection."""
        return selection.primary_selection

    def __repr__(self) -> str:
        return "Primary()"

    def __eq__(self, other: Any) -> bool:
        # Any read preference whose mode is primary compares equal.
        if not isinstance(other, _ServerMode):
            return NotImplemented
        return other.mode == _PRIMARY
|
||||
|
||||
|
||||
class PrimaryPreferred(_ServerMode):
    """PrimaryPreferred read preference.

    * When directly connected to one mongod queries are allowed to standalone
      servers, to a replica set primary, or to replica set secondaries.
    * When connected to a mongos queries are sent to the primary of a shard if
      available, otherwise a shard secondary.
    * When connected to a replica set queries are sent to the primary if
      available, otherwise a secondary.

    .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
      created reads will be routed to an available secondary until the
      primary of the replica set is discovered.

    :param tag_sets: The :attr:`~tag_sets` to use if the primary is not
        available.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` to use if the primary is not available.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        if selection.primary:
            return selection.primary_selection

        # No primary: fall back to tagged secondaries within the staleness bound.
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        return secondary_with_tags_server_selector(self.tag_sets, fresh_enough)
|
||||
|
||||
|
||||
class Secondary(_ServerMode):
    """Secondary read preference.

    * When directly connected to one mongod queries are allowed to standalone
      servers, to a replica set primary, or to replica set secondaries.
    * When connected to a mongos queries are distributed among shard
      secondaries. An error is raised if no secondaries are available.
    * When connected to a replica set queries are distributed among
      secondaries. An error is raised if no secondaries are available.

    :param tag_sets: The :attr:`~tag_sets` for this read preference.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` for this read preference.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        # Apply the staleness bound first, then match tag sets among secondaries.
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        return secondary_with_tags_server_selector(self.tag_sets, fresh_enough)
|
||||
|
||||
|
||||
class SecondaryPreferred(_ServerMode):
    """SecondaryPreferred read preference.

    * When directly connected to one mongod queries are allowed to standalone
      servers, to a replica set primary, or to replica set secondaries.
    * When connected to a mongos queries are distributed among shard
      secondaries, or the shard primary if no secondary is available.
    * When connected to a replica set queries are distributed among
      secondaries, or the primary if no secondary is available.

    .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
      created reads will be routed to the primary of the replica set until
      an available secondary is discovered.

    :param tag_sets: The :attr:`~tag_sets` for this read preference.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` for this read preference.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        matching = secondary_with_tags_server_selector(self.tag_sets, fresh_enough)

        # Use the primary only when no secondary qualifies.
        if not matching:
            return selection.primary_selection
        return matching
|
||||
|
||||
|
||||
class Nearest(_ServerMode):
    """Nearest read preference.

    * When directly connected to one mongod queries are allowed to standalone
      servers, to a replica set primary, or to replica set secondaries.
    * When connected to a mongos queries are distributed among all members of
      a shard.
    * When connected to a replica set queries are distributed among all
      members.

    :param tag_sets: The :attr:`~tag_sets` for this read preference.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` for this read preference.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_NEAREST, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        # Apply the staleness bound first, then match tag sets among all members.
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        return member_with_tags_server_selector(self.tag_sets, fresh_enough)
|
||||
|
||||
|
||||
class _AggWritePref:
|
||||
"""Agg $out/$merge write preference.
|
||||
|
||||
* If there are readable servers and there is any pre-5.0 server, use
|
||||
primary read preference.
|
||||
* Otherwise use `pref` read preference.
|
||||
|
||||
:param pref: The read preference to use on MongoDB 5.0+.
|
||||
"""
|
||||
|
||||
__slots__ = ("pref", "effective_pref")
|
||||
|
||||
def __init__(self, pref: _ServerMode):
|
||||
self.pref = pref
|
||||
self.effective_pref: _ServerMode = ReadPreference.PRIMARY
|
||||
|
||||
def selection_hook(self, topology_description: TopologyDescription) -> None:
|
||||
common_wv = topology_description.common_wire_version
|
||||
if (
|
||||
topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
|
||||
and common_wv
|
||||
and common_wv < 13
|
||||
):
|
||||
self.effective_pref = ReadPreference.PRIMARY
|
||||
else:
|
||||
self.effective_pref = self.pref
|
||||
|
||||
def __call__(self, selection: Selection) -> Selection:
|
||||
"""Apply this read preference to a Selection."""
|
||||
return self.effective_pref(selection)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"_AggWritePref(pref={self.pref!r})"
|
||||
|
||||
# Proxy other calls to the effective_pref so that _AggWritePref can be
|
||||
# used in place of an actual read preference.
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.effective_pref, name)
|
||||
|
||||
|
||||
# Read preference classes; make_read_preference() picks the class for a
# non-primary mode by indexing this tuple with the integer mode value.
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
|
||||
|
||||
|
||||
def make_read_preference(
    mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
) -> _ServerMode:
    """Build the read preference instance for an integer ``mode``.

    :param mode: Integer read preference mode.
    :param tag_sets: Tag sets to pass to the read preference, or None.
    :param max_staleness: Max staleness in seconds; -1 means unset.

    Raises ConfigurationError when the primary mode is combined with tags or
    with a max staleness.
    """
    if mode != _PRIMARY:
        # Every non-primary mode accepts tags and max staleness directly.
        return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness)  # type: ignore

    if tag_sets not in (None, [{}]):
        raise ConfigurationError("Read preference primary cannot be combined with tags")
    if max_staleness != -1:
        raise ConfigurationError(
            "Read preference primary cannot be combined with maxStalenessSeconds"
        )
    return Primary()
|
||||
|
||||
|
||||
_MODES = (
|
||||
"PRIMARY",
|
||||
"PRIMARY_PREFERRED",
|
||||
"SECONDARY",
|
||||
"SECONDARY_PREFERRED",
|
||||
"NEAREST",
|
||||
)
|
||||
|
||||
|
||||
class ReadPreference:
    """An enum that defines some commonly used read preference modes.

    Apps can also create a custom read preference, for example::

       Nearest(tag_sets=[{"node":"analytics"}])

    See :doc:`/examples/high_availability` for code examples.

    A read preference is used in three cases:

    :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:

    - ``PRIMARY``: Queries are allowed if the server is standalone or a replica
      set primary.
    - All other modes allow queries to standalone servers, to a replica set
      primary, or to replica set secondaries.

    :class:`~pymongo.mongo_client.MongoClient` initialized with the
    ``replicaSet`` option:

    - ``PRIMARY``: Read from the primary. This is the default, and provides the
      strongest consistency. If no primary is available, raise
      :class:`~pymongo.errors.AutoReconnect`.

    - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
      none, read from a secondary.

    - ``SECONDARY``: Read from a secondary. If no secondary is available,
      raise :class:`~pymongo.errors.AutoReconnect`.

    - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
      from the primary.

    - ``NEAREST``: Read from any member.

    :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
    sharded cluster of replica sets:

    - ``PRIMARY``: Read from the primary of the shard, or raise
      :class:`~pymongo.errors.OperationFailure` if there is none.
      This is the default.

    - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
      none, read from a secondary of the shard.

    - ``SECONDARY``: Read from a secondary of the shard, or raise
      :class:`~pymongo.errors.OperationFailure` if there is none.

    - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
      otherwise from the shard primary.

    - ``NEAREST``: Read from any shard member.
    """

    # One instance per mode, each built with its class's default arguments
    # and created once when this class body is executed.
    PRIMARY = Primary()
    PRIMARY_PREFERRED = PrimaryPreferred()
    SECONDARY = Secondary()
    SECONDARY_PREFERRED = SecondaryPreferred()
    NEAREST = Nearest()
|
||||
|
||||
|
||||
def read_pref_mode_from_name(name: str) -> int:
    """Get the read preference mode from mongos/uri name.

    The mode is the position of ``name`` in ``_MONGOS_MODES``; ``index``
    raises ValueError when the name is not present.
    """
    mode = _MONGOS_MODES.index(name)
    return mode
|
||||
|
||||
|
||||
class MovingAverage:
    """Tracks an exponentially-weighted moving average."""

    # Current average, or None until the first valid sample arrives.
    average: Optional[float]

    def __init__(self) -> None:
        self.average = None

    def add_sample(self, sample: float) -> None:
        """Fold ``sample`` into the average; negative samples are ignored."""
        if sample < 0:
            # Likely system time change while waiting for hello response
            # and not using time.monotonic. Ignore it, the next one will
            # probably be valid.
            return
        previous = self.average
        # The Server Selection Spec requires an exponentially weighted
        # average with alpha = 0.2. The first sample seeds the average.
        self.average = sample if previous is None else 0.8 * previous + 0.2 * sample

    def get(self) -> Optional[float]:
        """Get the calculated average, or None if no samples yet."""
        return self.average

    def reset(self) -> None:
        """Forget all samples so the next one seeds the average again."""
        self.average = None
|
||||
133
pymongo/asynchronous/response.py
Normal file
133
pymongo/asynchronous/response.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Represent a response from the server."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.asynchronous.message import _OpMsg, _OpReply
|
||||
from pymongo.asynchronous.pool import Connection
|
||||
from pymongo.asynchronous.typings import _Address, _DocumentOut
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class Response:
    """Read-only holder for one server reply and its metadata."""

    __slots__ = (
        "_data",
        "_address",
        "_request_id",
        "_duration",
        "_from_command",
        "_docs",
    )

    def __init__(
        self,
        data: Union[_OpMsg, _OpReply],
        address: _Address,
        request_id: int,
        duration: Optional[timedelta],
        from_command: bool,
        docs: Sequence[Mapping[str, Any]],
    ):
        """Represent a response from the server.

        :param data: A network response message.
        :param address: (host, port) of the source server.
        :param request_id: The request id of this operation.
        :param duration: The duration of the operation.
        :param from_command: if the response is the result of a db command.
        :param docs: The decoded document(s).
        """
        self._data, self._address = data, address
        self._request_id, self._duration = request_id, duration
        self._from_command, self._docs = from_command, docs

    @property
    def data(self) -> Union[_OpMsg, _OpReply]:
        """The network response message passed to the constructor."""
        return self._data

    @property
    def address(self) -> _Address:
        """(host, port) of the source server."""
        return self._address

    @property
    def request_id(self) -> int:
        """The request id of this operation."""
        return self._request_id

    @property
    def duration(self) -> Optional[timedelta]:
        """The duration of the operation."""
        return self._duration

    @property
    def from_command(self) -> bool:
        """If the response is a result from a db command."""
        return self._from_command

    @property
    def docs(self) -> Sequence[Mapping[str, Any]]:
        """The decoded document(s)."""
        return self._docs
|
||||
|
||||
|
||||
class PinnedResponse(Response):
    """A Response that also carries the connection it arrived on."""

    __slots__ = ("_conn", "_more_to_come")

    def __init__(
        self,
        data: Union[_OpMsg, _OpReply],
        address: _Address,
        conn: Connection,
        request_id: int,
        duration: Optional[timedelta],
        from_command: bool,
        docs: list[_DocumentOut],
        more_to_come: bool,
    ):
        """Represent a response to an exhaust cursor's initial query.

        :param data: A network response message.
        :param address: (host, port) of the source server.
        :param conn: The Connection used for the initial query.
        :param request_id: The request id of this operation.
        :param duration: The duration of the operation.
        :param from_command: If the response is the result of a db command.
        :param docs: List of documents.
        :param more_to_come: Bool indicating whether cursor is ready to be
            exhausted.
        """
        super().__init__(
            data=data,
            address=address,
            request_id=request_id,
            duration=duration,
            from_command=from_command,
            docs=docs,
        )
        self._conn, self._more_to_come = conn, more_to_come

    @property
    def conn(self) -> Connection:
        """The Connection used for the initial query.

        The server will send batches on this socket, without waiting for
        getMores from the client, until the result set is exhausted or there
        is an error.
        """
        return self._conn

    @property
    def more_to_come(self) -> bool:
        """If true, server is ready to send batches on the socket until the
        result set is exhausted or there is an error.
        """
        return self._more_to_come
|
||||
355
pymongo/asynchronous/server.py
Normal file
355
pymongo/asynchronous/server.py
Normal file
@ -0,0 +1,355 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Communicate with one MongoDB server in a topology."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo.asynchronous.helpers import _check_command_response, _handle_reauth
|
||||
from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query
|
||||
from pymongo.asynchronous.response import PinnedResponse, Response
|
||||
from pymongo.errors import NotPrimaryError, OperationFailure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from queue import Queue
|
||||
from weakref import ReferenceType
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
|
||||
from pymongo.asynchronous.monitor import Monitor
|
||||
from pymongo.asynchronous.monitoring import _EventListeners
|
||||
from pymongo.asynchronous.pool import Connection, Pool
|
||||
from pymongo.asynchronous.read_preferences import _ServerMode
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.typings import _DocumentOut
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
|
||||
|
||||
|
||||
class Server:
    """One server in a topology: its description, connection pool and monitor."""

    def __init__(
        self,
        server_description: ServerDescription,
        pool: Pool,
        monitor: Monitor,
        topology_id: Optional[ObjectId] = None,
        listeners: Optional[_EventListeners] = None,
        events: Optional[ReferenceType[Queue]] = None,
    ) -> None:
        """Represent one MongoDB server.

        :param server_description: Initial ServerDescription for this server.
        :param pool: Connection pool used to talk to the server.
        :param monitor: Monitor that checks the server's state.
        :param topology_id: Id passed along when publishing server events.
        :param listeners: Event listeners, or None to publish nothing.
        :param events: Weak reference to the queue that receives events; only
            dereferenced when ``listeners`` is enabled for server events.
        """
        self._description = server_description
        self._pool = pool
        self._monitor = monitor
        self._topology_id = topology_id
        self._publish = listeners is not None and listeners.enabled_for_server
        self._listener = listeners
        self._events = None
        if self._publish:
            # `events` is a weak reference; calling it yields the referent.
            self._events = events()  # type: ignore[misc]

    async def open(self) -> None:
        """Start monitoring, or restart after a fork.

        Multiple calls have no effect.
        """
        # The monitor is not opened when the pool is in load balanced mode.
        if not self._pool.opts.load_balanced:
            self._monitor.open()

    async def reset(self, service_id: Optional[ObjectId] = None) -> None:
        """Clear the connection pool."""
        await self.pool.reset(service_id)

    async def close(self) -> None:
        """Clear the connection pool and stop the monitor.

        Reconnect with open().
        """
        if self._publish:
            assert self._listener is not None
            assert self._events is not None
            # Queue a (callback, args) pair rather than calling the listener
            # directly from here.
            self._events.put(
                (
                    self._listener.publish_server_closed,
                    (self._description.address, self._topology_id),
                )
            )
        await self._monitor.close()
        await self._pool.close()

    def request_check(self) -> None:
        """Check the server's state soon."""
        self._monitor.request_check()

    @_handle_reauth
    async def run_operation(
        self,
        conn: Connection,
        operation: Union[_Query, _GetMore],
        read_preference: _ServerMode,
        listeners: Optional[_EventListeners],
        unpack_res: Callable[..., list[_DocumentOut]],
        client: AsyncMongoClient,
    ) -> Response:
        """Run a _Query or _GetMore operation and return a Response object.

        This method is used only to run _Query/_GetMore operations from
        cursors.
        Can raise ConnectionFailure, OperationFailure, etc.

        :param conn: A Connection instance.
        :param operation: A _Query or _GetMore object.
        :param read_preference: The read preference to use.
        :param listeners: Instance of _EventListeners or None.
        :param unpack_res: A callable that decodes the wire protocol response.
        :param client: The client; its topology id is attached to the command
            log messages emitted here.
        """
        duration = None
        assert listeners is not None
        publish = listeners.enabled_for_commands
        start = datetime.now()

        use_cmd = operation.use_command(conn)
        more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
        if more_to_come:
            # Nothing is sent in this case (see the receive below), so there
            # is no message to build and request id 0 is reported.
            request_id = 0
        else:
            message = await operation.get_message(read_preference, conn, use_cmd)
            request_id, data, max_doc_size = self._split_message(message)

        cmd, dbn = await operation.as_command(conn)
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=dbn,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )

        if publish:
            # NOTE(review): as_command() was already awaited just above; this
            # second call looks redundant -- confirm it is idempotent/cached.
            cmd, dbn = await operation.as_command(conn)
            if "$db" not in cmd:
                cmd["$db"] = dbn
            assert listeners is not None
            listeners.publish_command_start(
                cmd,
                dbn,
                request_id,
                conn.address,
                conn.server_connection_id,
                service_id=conn.service_id,
            )

        try:
            if more_to_come:
                # Only read the next reply; no request is sent.
                reply = await conn.receive_message(None)
            else:
                await conn.send_message(data, max_doc_size)
                reply = await conn.receive_message(request_id)

            # Unpack and check for command errors.
            if use_cmd:
                user_fields = _CURSOR_DOC_FIELDS
                legacy_response = False
            else:
                user_fields = None
                legacy_response = True
            docs = unpack_res(
                reply,
                operation.cursor_id,
                operation.codec_options,
                legacy_response=legacy_response,
                user_fields=user_fields,
            )
            if use_cmd:
                first = docs[0]
                await operation.client._process_response(first, operation.session)
                _check_command_response(first, conn.max_wire_version)
        except Exception as exc:
            # Log and publish the failure, then re-raise the same exception.
            duration = datetime.now() - start
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure: _DocumentOut = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=dbn,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=conn.address[0],
                    serverPort=conn.address[1],
                    serviceId=conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if publish:
                assert listeners is not None
                listeners.publish_command_failure(
                    duration,
                    failure,
                    operation.name,
                    request_id,
                    conn.address,
                    conn.server_connection_id,
                    service_id=conn.service_id,
                    database_name=dbn,
                )
            raise
        duration = datetime.now() - start
        # Must publish in find / getMore / explain command response
        # format.
        if use_cmd:
            res = docs[0]
        elif operation.name == "explain":
            res = docs[0] if docs else {}
        else:
            # Not a command reply: build a command-shaped document around
            # the decoded docs.
            res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}  # type: ignore[union-attr]
            if operation.name == "find":
                res["cursor"]["firstBatch"] = docs
            else:
                res["cursor"]["nextBatch"] = docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.SUCCEEDED,
                durationMS=duration,
                reply=res,
                commandName=next(iter(cmd)),
                databaseName=dbn,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )
        if publish:
            assert listeners is not None
            listeners.publish_command_success(
                duration,
                res,
                operation.name,
                request_id,
                conn.address,
                conn.server_connection_id,
                service_id=conn.service_id,
                database_name=dbn,
            )

        # Decrypt response.
        # From here on `client` refers to operation.client, not the parameter.
        client = operation.client
        if client and client._encrypter:
            if use_cmd:
                # NOTE(review): decrypt() is not awaited here -- confirm it is
                # a synchronous call in the async API.
                decrypted = client._encrypter.decrypt(reply.raw_command_response())
                docs = _decode_all_selective(decrypted, operation.codec_options, user_fields)

        response: Response

        if client._should_pin_cursor(operation.session) or operation.exhaust:
            conn.pin_cursor()
            if isinstance(reply, _OpMsg):
                # In OP_MSG, the server keeps sending only if the
                # more_to_come flag is set.
                more_to_come = reply.more_to_come
            else:
                # In OP_REPLY, the server keeps sending until cursor_id is 0.
                more_to_come = bool(operation.exhaust and reply.cursor_id)
            if operation.conn_mgr:
                operation.conn_mgr.update_exhaust(more_to_come)
            response = PinnedResponse(
                data=reply,
                address=self._description.address,
                conn=conn,
                duration=duration,
                request_id=request_id,
                from_command=use_cmd,
                docs=docs,
                more_to_come=more_to_come,
            )
        else:
            response = Response(
                data=reply,
                address=self._description.address,
                duration=duration,
                request_id=request_id,
                from_command=use_cmd,
                docs=docs,
            )

        return response

    async def checkout(
        self, handler: Optional[_MongoClientErrorHandler] = None
    ) -> AsyncContextManager[Connection]:
        """Return the pool's checkout context manager for ``handler``.

        This coroutine only forwards to ``self.pool.checkout``; the value it
        returns is what callers enter to obtain a Connection.
        """
        return self.pool.checkout(handler)

    @property
    def description(self) -> ServerDescription:
        """The current ServerDescription for this server."""
        return self._description

    @description.setter
    def description(self, server_description: ServerDescription) -> None:
        # A replacement description must be for the same address.
        assert server_description.address == self._description.address
        self._description = server_description

    @property
    def pool(self) -> Pool:
        """The connection pool passed to the constructor."""
        return self._pool

    def _split_message(
        self, message: Union[tuple[int, Any], tuple[int, Any, int]]
    ) -> tuple[int, Any, int]:
        """Return request_id, data, max_doc_size.

        :param message: (request_id, data, max_doc_size) or (request_id, data)
        """
        if len(message) == 3:
            return message  # type: ignore[return-value]
        else:
            # get_more and kill_cursors messages don't include BSON documents.
            request_id, data = message  # type: ignore[misc]
            return request_id, data, 0

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._description!r}>"
|
||||
301
pymongo/asynchronous/server_description.py
Normal file
301
pymongo/asynchronous/server_description.py
Normal file
@ -0,0 +1,301 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Represent one server the driver is connected to."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from bson import EPOCH_NAIVE
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.hello import Hello
|
||||
from pymongo.asynchronous.typings import ClusterTime, _Address
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class ServerDescription:
|
||||
"""Immutable representation of one server.
|
||||
|
||||
:param address: A (host, port) pair
|
||||
:param hello: Optional Hello instance
|
||||
:param round_trip_time: Optional float
|
||||
:param error: Optional, the last error attempting to connect to the server
|
||||
:param round_trip_time: Optional float, the min latency from the most recent samples
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_address",
|
||||
"_server_type",
|
||||
"_all_hosts",
|
||||
"_tags",
|
||||
"_replica_set_name",
|
||||
"_primary",
|
||||
"_max_bson_size",
|
||||
"_max_message_size",
|
||||
"_max_write_batch_size",
|
||||
"_min_wire_version",
|
||||
"_max_wire_version",
|
||||
"_round_trip_time",
|
||||
"_min_round_trip_time",
|
||||
"_me",
|
||||
"_is_writable",
|
||||
"_is_readable",
|
||||
"_ls_timeout_minutes",
|
||||
"_error",
|
||||
"_set_version",
|
||||
"_election_id",
|
||||
"_cluster_time",
|
||||
"_last_write_date",
|
||||
"_last_update_time",
|
||||
"_topology_version",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: _Address,
|
||||
hello: Optional[Hello] = None,
|
||||
round_trip_time: Optional[float] = None,
|
||||
error: Optional[Exception] = None,
|
||||
min_round_trip_time: float = 0.0,
|
||||
) -> None:
|
||||
self._address = address
|
||||
if not hello:
|
||||
hello = Hello({})
|
||||
|
||||
self._server_type = hello.server_type
|
||||
self._all_hosts = hello.all_hosts
|
||||
self._tags = hello.tags
|
||||
self._replica_set_name = hello.replica_set_name
|
||||
self._primary = hello.primary
|
||||
self._max_bson_size = hello.max_bson_size
|
||||
self._max_message_size = hello.max_message_size
|
||||
self._max_write_batch_size = hello.max_write_batch_size
|
||||
self._min_wire_version = hello.min_wire_version
|
||||
self._max_wire_version = hello.max_wire_version
|
||||
self._set_version = hello.set_version
|
||||
self._election_id = hello.election_id
|
||||
self._cluster_time = hello.cluster_time
|
||||
self._is_writable = hello.is_writable
|
||||
self._is_readable = hello.is_readable
|
||||
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
|
||||
self._round_trip_time = round_trip_time
|
||||
self._min_round_trip_time = min_round_trip_time
|
||||
self._me = hello.me
|
||||
self._last_update_time = time.monotonic()
|
||||
self._error = error
|
||||
self._topology_version = hello.topology_version
|
||||
if error:
|
||||
details = getattr(error, "details", None)
|
||||
if isinstance(details, dict):
|
||||
self._topology_version = details.get("topologyVersion")
|
||||
|
||||
self._last_write_date: Optional[float]
|
||||
if hello.last_write_date:
|
||||
# Convert from datetime to seconds.
|
||||
delta = hello.last_write_date - EPOCH_NAIVE
|
||||
self._last_write_date = delta.total_seconds()
|
||||
else:
|
||||
self._last_write_date = None
|
||||
|
||||
@property
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) of this server."""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def server_type(self) -> int:
|
||||
"""The type of this server."""
|
||||
return self._server_type
|
||||
|
||||
@property
|
||||
def server_type_name(self) -> str:
|
||||
"""The server type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return SERVER_TYPE._fields[self._server_type]
|
||||
|
||||
@property
|
||||
def all_hosts(self) -> set[tuple[str, int]]:
|
||||
"""List of hosts, passives, and arbiters known to this server."""
|
||||
return self._all_hosts
|
||||
|
||||
@property
|
||||
def tags(self) -> Mapping[str, Any]:
|
||||
return self._tags
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def primary(self) -> Optional[tuple[str, int]]:
|
||||
"""This server's opinion about who the primary is, or None."""
|
||||
return self._primary
|
||||
|
||||
@property
|
||||
def max_bson_size(self) -> int:
|
||||
return self._max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self) -> int:
|
||||
return self._max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self) -> int:
|
||||
return self._max_write_batch_size
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
return self._min_wire_version
|
||||
|
||||
@property
|
||||
def max_wire_version(self) -> int:
|
||||
return self._max_wire_version
|
||||
|
||||
@property
|
||||
def set_version(self) -> Optional[int]:
|
||||
return self._set_version
|
||||
|
||||
@property
|
||||
def election_id(self) -> Optional[ObjectId]:
|
||||
return self._election_id
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
|
||||
warnings.warn(
|
||||
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._set_version, self._election_id
|
||||
|
||||
@property
|
||||
def me(self) -> Optional[tuple[str, int]]:
|
||||
return self._me
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def last_write_date(self) -> Optional[float]:
|
||||
return self._last_write_date
|
||||
|
||||
@property
|
||||
def last_update_time(self) -> float:
|
||||
return self._last_update_time
|
||||
|
||||
@property
|
||||
def round_trip_time(self) -> Optional[float]:
|
||||
"""The current average latency or None."""
|
||||
# This override is for unittesting only!
|
||||
if self._address in self._host_to_round_trip_time:
|
||||
return self._host_to_round_trip_time[self._address]
|
||||
|
||||
return self._round_trip_time
|
||||
|
||||
@property
|
||||
def min_round_trip_time(self) -> float:
|
||||
"""The min latency from the most recent samples."""
|
||||
return self._min_round_trip_time
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[Exception]:
|
||||
"""The last error attempting to connect to the server, or None."""
|
||||
return self._error
|
||||
|
||||
@property
|
||||
def is_writable(self) -> bool:
|
||||
return self._is_writable
|
||||
|
||||
@property
|
||||
def is_readable(self) -> bool:
|
||||
return self._is_readable
|
||||
|
||||
@property
|
||||
def mongos(self) -> bool:
|
||||
return self._server_type == SERVER_TYPE.Mongos
|
||||
|
||||
@property
|
||||
def is_server_type_known(self) -> bool:
|
||||
return self.server_type != SERVER_TYPE.Unknown
|
||||
|
||||
@property
|
||||
def retryable_writes_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return (
|
||||
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
|
||||
) or self._server_type == SERVER_TYPE.LoadBalancer
|
||||
|
||||
@property
|
||||
def retryable_reads_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return self._max_wire_version >= 6
|
||||
|
||||
@property
|
||||
def topology_version(self) -> Optional[Mapping[str, Any]]:
|
||||
return self._topology_version
|
||||
|
||||
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
|
||||
unknown = ServerDescription(self.address, error=error)
|
||||
unknown._topology_version = self.topology_version
|
||||
return unknown
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, ServerDescription):
|
||||
return (
|
||||
(self._address == other.address)
|
||||
and (self._server_type == other.server_type)
|
||||
and (self._min_wire_version == other.min_wire_version)
|
||||
and (self._max_wire_version == other.max_wire_version)
|
||||
and (self._me == other.me)
|
||||
and (self._all_hosts == other.all_hosts)
|
||||
and (self._tags == other.tags)
|
||||
and (self._replica_set_name == other.replica_set_name)
|
||||
and (self._set_version == other.set_version)
|
||||
and (self._election_id == other.election_id)
|
||||
and (self._primary == other.primary)
|
||||
and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
|
||||
and (self._error == other.error)
|
||||
)
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self) -> str:
    """Debug representation: class name, address, server type name and RTT.

    The error's repr is appended only when ``self.error`` is truthy.
    """
    error_suffix = f", error={self.error!r}" if self.error else ""
    template = "<{} {} server_type: {}, rtt: {}{}>"
    return template.format(
        self.__class__.__name__,
        self.address,
        self.server_type_name,
        self.round_trip_time,
        error_suffix,
    )
|
||||
|
||||
# For unittesting only. Use under no circumstances!
|
||||
_host_to_round_trip_time: dict = {}
|
||||
175
pymongo/asynchronous/server_selectors.py
Normal file
175
pymongo/asynchronous/server_selectors.py
Normal file
@ -0,0 +1,175 @@
|
||||
# Copyright 2014-2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Criteria to select some ServerDescriptions from a TopologyDescription."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast
|
||||
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.topology_description import TopologyDescription
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
T = TypeVar("T")
|
||||
TagSet = Mapping[str, Any]
|
||||
TagSets = Sequence[TagSet]
|
||||
|
||||
|
||||
class Selection:
    """Input or output of a server selector function."""

    @classmethod
    def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
        """Build a Selection holding every known server of a topology.

        The first known server whose type is ``SERVER_TYPE.RSPrimary``, if
        any, is recorded as the selection's primary.

        :param topology_description: the topology to take servers from.
        """
        # Read the property once and reuse the resulting list both for the
        # primary search and as the selection's server descriptions.
        known_servers = topology_description.known_servers
        primary = next(
            (sd for sd in known_servers if sd.server_type == SERVER_TYPE.RSPrimary),
            None,
        )

        return Selection(
            topology_description,
            known_servers,
            topology_description.common_wire_version,
            primary,
        )

    def __init__(
        self,
        topology_description: TopologyDescription,
        server_descriptions: list[ServerDescription],
        common_wire_version: Optional[int],
        primary: Optional[ServerDescription],
    ):
        """Store the topology, the candidate servers and related metadata.

        :param topology_description: the topology this selection came from.
        :param server_descriptions: the servers currently selected.
        :param common_wire_version: wire version value carried along with
            the selection, or None.
        :param primary: the primary's description, or None.
        """
        self.topology_description = topology_description
        self.server_descriptions = server_descriptions
        self.primary = primary
        self.common_wire_version = common_wire_version

    def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection:
        """Return a new Selection that differs only in its server list."""
        return Selection(
            self.topology_description, server_descriptions, self.common_wire_version, self.primary
        )

    def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]:
        """Return the secondary with the greatest ``last_write_date``.

        Returns None when the selection contains no secondaries.
        """
        secondaries = secondary_server_selector(self)
        if secondaries.server_descriptions:
            return max(
                secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date)
            )
        return None

    @property
    def primary_selection(self) -> Selection:
        """A Selection containing only the primary (empty if there is none)."""
        primaries = [self.primary] if self.primary else []
        return self.with_server_descriptions(primaries)

    @property
    def heartbeat_frequency(self) -> int:
        """The ``heartbeat_frequency`` of the underlying topology description."""
        return self.topology_description.heartbeat_frequency

    @property
    def topology_type(self) -> int:
        """The ``topology_type`` of the underlying topology description."""
        return self.topology_description.topology_type

    def __bool__(self) -> bool:
        """A Selection is truthy when it holds at least one server."""
        return bool(self.server_descriptions)

    def __getitem__(self, item: int) -> ServerDescription:
        """Index into the selected server descriptions."""
        return self.server_descriptions[item]
|
||||
|
||||
|
||||
def any_server_selector(selection: T) -> T:
    """Selector that accepts everything: hand ``selection`` back unchanged."""
    unchanged = selection
    return unchanged
|
||||
|
||||
|
||||
def readable_server_selector(selection: Selection) -> Selection:
    """Narrow ``selection`` to the servers whose ``is_readable`` is truthy."""
    readable = [server for server in selection.server_descriptions if server.is_readable]
    return selection.with_server_descriptions(readable)
|
||||
|
||||
|
||||
def writable_server_selector(selection: Selection) -> Selection:
    """Narrow ``selection`` to the servers whose ``is_writable`` is truthy."""
    writable = [server for server in selection.server_descriptions if server.is_writable]
    return selection.with_server_descriptions(writable)
|
||||
|
||||
|
||||
def secondary_server_selector(selection: Selection) -> Selection:
    """Narrow ``selection`` to servers of type ``SERVER_TYPE.RSSecondary``."""
    secondaries = [
        server
        for server in selection.server_descriptions
        if server.server_type == SERVER_TYPE.RSSecondary
    ]
    return selection.with_server_descriptions(secondaries)
|
||||
|
||||
|
||||
def arbiter_server_selector(selection: Selection) -> Selection:
    """Narrow ``selection`` to servers of type ``SERVER_TYPE.RSArbiter``."""
    arbiters = [
        server
        for server in selection.server_descriptions
        if server.server_type == SERVER_TYPE.RSArbiter
    ]
    return selection.with_server_descriptions(arbiters)
|
||||
|
||||
|
||||
def writable_preferred_server_selector(selection: Selection) -> Selection:
    """Like PrimaryPreferred but doesn't use tags or latency.

    Returns the writable servers when there are any, otherwise the
    secondaries.
    """
    writable = writable_server_selector(selection)
    if writable:
        return writable
    return secondary_server_selector(selection)
|
||||
|
||||
|
||||
def apply_single_tag_set(tag_set: TagSet, selection: Selection) -> Selection:
    """Keep only the servers matching one tag set.

    A tag set is a dict. A server matches when its tags are a superset of
    the tag set: a server tagged {'a': '1', 'b': '2'} matches {'a': '1'}.

    The empty tag set {} matches every server.
    """

    def tags_match(server_tags: Mapping[str, Any]) -> bool:
        # A single missing key or differing value disqualifies the server.
        return not any(
            key not in server_tags or server_tags[key] != value
            for key, value in tag_set.items()
        )

    matching = [server for server in selection.server_descriptions if tags_match(server.tags)]
    return selection.with_server_descriptions(matching)
|
||||
|
||||
|
||||
def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection:
    """All servers matching a list of tag sets.

    tag_sets is a list of dicts tried in order; the result for the first
    tag set that matches at least one server is returned. The empty tag set
    {} matches any server, and may be provided at the end of the list as a
    fallback. So [{'a': 'value'}, {}] expresses a preference for servers
    tagged {'a': 'value'}, but accepts any server if none matches the first
    preference.

    When no tag set matches, an empty selection is returned.
    """
    # Lazily evaluated so later tag sets are only tried when needed.
    per_tag_set = (apply_single_tag_set(tag_set, selection) for tag_set in tag_sets)
    first_match = next((matched for matched in per_tag_set if matched), None)
    if first_match is not None:
        return first_match
    return selection.with_server_descriptions([])
|
||||
|
||||
|
||||
def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
    """Secondaries in ``selection`` that match the tag sets."""
    secondaries = secondary_server_selector(selection)
    return apply_tag_sets(tag_sets, secondaries)
|
||||
|
||||
|
||||
def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection:
    """Readable members in ``selection`` that match the tag sets."""
    readable_members = readable_server_selector(selection)
    return apply_tag_sets(tag_sets, readable_members)
|
||||
170
pymongo/asynchronous/settings.py
Normal file
170
pymongo/asynchronous/settings.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Represent MongoClient's configuration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Collection, Optional, Type, Union
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous import common, monitor, pool
|
||||
from pymongo.asynchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
|
||||
from pymongo.asynchronous.pool import Pool, PoolOptions
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TopologySettings:
    def __init__(
        self,
        seeds: Optional[Collection[tuple[str, int]]] = None,
        replica_set_name: Optional[str] = None,
        pool_class: Optional[Type[Pool]] = None,
        pool_options: Optional[PoolOptions] = None,
        monitor_class: Optional[Type[monitor.Monitor]] = None,
        condition_class: Optional[Type[threading.Condition]] = None,
        local_threshold_ms: int = LOCAL_THRESHOLD_MS,
        server_selection_timeout: int = SERVER_SELECTION_TIMEOUT,
        heartbeat_frequency: int = common.HEARTBEAT_FREQUENCY,
        server_selector: Optional[_ServerSelector] = None,
        fqdn: Optional[str] = None,
        direct_connection: Optional[bool] = False,
        load_balanced: Optional[bool] = None,
        srv_service_name: str = common.SRV_SERVICE_NAME,
        srv_max_hosts: int = 0,
        server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
    ):
        """Represent MongoClient's configuration.

        Takes a collection of (host, port) seed pairs plus optional
        settings. Falsy ``seeds``, ``pool_class``, ``pool_options``,
        ``monitor_class`` and ``condition_class`` fall back to defaults.

        Raises ConfigurationError when ``heartbeat_frequency`` is below
        ``common.MIN_HEARTBEAT_INTERVAL``.
        """
        if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL:
            minimum_ms = common.MIN_HEARTBEAT_INTERVAL * 1000
            raise ConfigurationError("heartbeatFrequencyMS cannot be less than %d" % (minimum_ms,))

        # Values that fall back to a default when not supplied.
        self._seeds: Collection[tuple[str, int]] = seeds if seeds else [("localhost", 27017)]
        self._pool_class: Type[Pool] = pool_class if pool_class else pool.Pool
        self._pool_options: PoolOptions = pool_options if pool_options else PoolOptions()
        self._monitor_class: Type[monitor.Monitor] = (
            monitor_class if monitor_class else monitor.Monitor
        )
        self._condition_class: Type[threading.Condition] = (
            condition_class if condition_class else threading.Condition
        )
        self._srv_max_hosts = srv_max_hosts if srv_max_hosts else 0

        # Values stored exactly as given.
        self._replica_set_name = replica_set_name
        self._local_threshold_ms = local_threshold_ms
        self._server_selection_timeout = server_selection_timeout
        self._server_selector = server_selector
        self._fqdn = fqdn
        self._heartbeat_frequency = heartbeat_frequency
        self._direct = direct_connection
        self._load_balanced = load_balanced
        self._srv_service_name = srv_service_name
        self._server_monitoring_mode = server_monitoring_mode

        self._topology_id = ObjectId()
        # Store the allocation traceback to catch unclosed clients in the
        # test suite.
        self._stack = "".join(traceback.format_stack())

    @property
    def seeds(self) -> Collection[tuple[str, int]]:
        """List of server addresses."""
        return self._seeds

    @property
    def replica_set_name(self) -> Optional[str]:
        """The configured replica set name, or None."""
        return self._replica_set_name

    @property
    def pool_class(self) -> Type[Pool]:
        """The connection pool class."""
        return self._pool_class

    @property
    def pool_options(self) -> PoolOptions:
        """The pool options."""
        return self._pool_options

    @property
    def monitor_class(self) -> Type[monitor.Monitor]:
        """The monitor class."""
        return self._monitor_class

    @property
    def condition_class(self) -> Type[threading.Condition]:
        """The condition class."""
        return self._condition_class

    @property
    def local_threshold_ms(self) -> int:
        """The local threshold, in milliseconds."""
        return self._local_threshold_ms

    @property
    def server_selection_timeout(self) -> int:
        """The server selection timeout."""
        return self._server_selection_timeout

    @property
    def server_selector(self) -> Optional[_ServerSelector]:
        """The custom server selector, or None."""
        return self._server_selector

    @property
    def heartbeat_frequency(self) -> int:
        """The heartbeat frequency."""
        return self._heartbeat_frequency

    @property
    def fqdn(self) -> Optional[str]:
        """The ``fqdn`` value given to the constructor, or None."""
        return self._fqdn

    @property
    def direct(self) -> Optional[bool]:
        """Connect directly to a single server, or use a set of servers?

        This is the ``direct_connection`` value given to the constructor.
        """
        return self._direct

    @property
    def load_balanced(self) -> Optional[bool]:
        """True if the client was configured to connect to a load balancer."""
        return self._load_balanced

    @property
    def srv_service_name(self) -> str:
        """The srvServiceName."""
        return self._srv_service_name

    @property
    def srv_max_hosts(self) -> int:
        """The srvMaxHosts."""
        return self._srv_max_hosts

    @property
    def server_monitoring_mode(self) -> str:
        """The serverMonitoringMode."""
        return self._server_monitoring_mode

    def get_topology_type(self) -> int:
        """Return the initial TOPOLOGY_TYPE implied by these settings.

        Precedence: load balanced, then direct connection, then a
        configured replica set name; otherwise Unknown.
        """
        if self.load_balanced:
            return TOPOLOGY_TYPE.LoadBalanced
        if self.direct:
            return TOPOLOGY_TYPE.Single
        if self.replica_set_name is not None:
            return TOPOLOGY_TYPE.ReplicaSetNoPrimary
        return TOPOLOGY_TYPE.Unknown

    def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]:
        """Initial dict of (address, ServerDescription) for all seeds."""
        return {seed: ServerDescription(seed) for seed in self.seeds}
|
||||
149
pymongo/asynchronous/srv_resolver.py
Normal file
149
pymongo/asynchronous/srv_resolver.py
Normal file
@ -0,0 +1,149 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo.asynchronous.common import CONNECT_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dns import resolver
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _have_dnspython() -> bool:
    """Report whether the ``dns`` package can be imported."""
    try:
        import dns  # noqa: F401
    except ImportError:
        return False
    return True
|
||||
|
||||
|
||||
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
    """Return ``text`` as ``str``, decoding it first when it is ``bytes``."""
    return text.decode() if isinstance(text, bytes) else text
|
||||
|
||||
|
||||
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
    """Run a DNS query through whichever dnspython entry point exists.

    All arguments are forwarded unchanged to ``resolver.resolve`` when the
    module has it (dnspython >= 2), otherwise to ``resolver.query``
    (dnspython 1.X).
    """
    from dns import resolver

    query = resolver.resolve if hasattr(resolver, "resolve") else resolver.query
    return query(*args, **kwargs)
|
||||
|
||||
|
||||
_INVALID_HOST_MSG = (
|
||||
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
|
||||
"Did you mean to use 'mongodb://'?"
|
||||
)
|
||||
|
||||
|
||||
class _SrvResolver:
    """Look up the SRV and TXT records behind one ``mongodb+srv://`` host."""

    def __init__(
        self,
        fqdn: str,
        connect_timeout: Optional[float],
        srv_service_name: str,
        srv_max_hosts: int = 0,
    ):
        """Store the lookup parameters and validate ``fqdn``.

        :param fqdn: host name taken from the URI.
        :param connect_timeout: DNS query lifetime; a falsy value falls
            back to ``CONNECT_TIMEOUT``.
        :param srv_service_name: service name used to build the SRV query.
        :param srv_max_hosts: when non-zero, cap on the hosts returned.

        Raises ConfigurationError when ``fqdn`` is an IP address or has
        fewer than three dot-separated parts.
        """
        self.__fqdn = fqdn
        self.__srv = srv_service_name
        self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
        self.__srv_max_hosts = srv_max_hosts or 0

        # Validate the fully qualified domain name: an IP literal is not a
        # usable SRV host name.
        try:
            ipaddress.ip_address(fqdn)
        except ValueError:
            pass
        else:
            raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))

        # Everything after the first label is the domain that resolved
        # hosts are compared against.
        try:
            self.__plist = self.__fqdn.split(".")[1:]
        except Exception:
            raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
        self.__slen = len(self.__plist)
        if self.__slen < 2:
            raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))

    def get_options(self) -> Optional[str]:
        """Return the URI options held in the host's TXT record.

        Returns None when there is no TXT record. Raises ConfigurationError
        for any other lookup failure or when more than one record exists.
        """
        from dns import resolver

        try:
            records = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
        except (resolver.NoAnswer, resolver.NXDOMAIN):
            # No TXT records
            return None
        except Exception as exc:
            raise ConfigurationError(str(exc)) from None
        if len(records) > 1:
            raise ConfigurationError("Only one TXT record is supported")
        joined = b"&".join([b"".join(record.strings) for record in records])
        return joined.decode("utf-8")

    def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
        """Query the SRV record for the configured service and host.

        :param encapsulate_errors: when true, any failure is re-raised as
            ConfigurationError; otherwise the original error propagates.
        """
        try:
            return _resolve(
                "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
            )
        except Exception as exc:
            if not encapsulate_errors:
                # Raise the original error.
                raise
            # Else, raise all errors as ConfigurationError.
            raise ConfigurationError(str(exc)) from None

    def _get_srv_response_and_hosts(
        self, encapsulate_errors: bool
    ) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
        """Resolve the SRV record and return it with the validated hosts.

        Every returned host must share the parent domain of the configured
        name, otherwise ConfigurationError is raised. When a host cap is
        set, a random sample of at most that many hosts is returned.
        """
        answer = self._resolve_uri(encapsulate_errors)

        # Construct address tuples
        hosts = [
            (maybe_decode(record.target.to_text(omit_final_dot=True)), record.port)
            for record in answer
        ]

        # Validate hosts
        for host in hosts:
            hostname = host[0]
            try:
                domain_parts = hostname.lower().split(".")[1:][-self.__slen :]
            except Exception:
                raise ConfigurationError(f"Invalid SRV host: {host[0]}") from None
            if self.__plist != domain_parts:
                raise ConfigurationError(f"Invalid SRV host: {host[0]}")
        if self.__srv_max_hosts:
            sample_size = min(self.__srv_max_hosts, len(hosts))
            hosts = random.sample(hosts, sample_size)
        return answer, hosts

    def get_hosts(self) -> list[tuple[str, Any]]:
        """Return the resolved hosts; lookup errors become ConfigurationError."""
        _, hosts = self._get_srv_response_and_hosts(True)
        return hosts

    def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
        """Return the resolved hosts and the answer's TTL (0 without an rrset).

        Lookup errors propagate unchanged.
        """
        answer, hosts = self._get_srv_response_and_hosts(False)
        rrset = answer.rrset
        if rrset:
            ttl = rrset.ttl
        else:
            ttl = 0
        return hosts, ttl
|
||||
1030
pymongo/asynchronous/topology.py
Normal file
1030
pymongo/asynchronous/topology.py
Normal file
File diff suppressed because it is too large
Load Diff
678
pymongo/asynchronous/topology_description.py
Normal file
678
pymongo/asynchronous/topology_description.py
Normal file
@ -0,0 +1,678 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Represent a deployment of MongoDB servers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from random import sample
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
|
||||
from bson.min_key import MinKey
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous import common
|
||||
from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode
|
||||
from pymongo.asynchronous.server_description import ServerDescription
|
||||
from pymongo.asynchronous.server_selectors import Selection
|
||||
from pymongo.asynchronous.typings import _Address
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
# Enumeration for various kinds of MongoDB cluster topologies.
class _TopologyType(NamedTuple):
    # One integer code per topology kind. The field order matters: the
    # module instantiates this tuple positionally from a range of integers.
    Single: int
    ReplicaSetNoPrimary: int
    ReplicaSetWithPrimary: int
    Sharded: int
    Unknown: int
    LoadBalanced: int
|
||||
|
||||
|
||||
TOPOLOGY_TYPE = _TopologyType(*range(6))
|
||||
|
||||
# Topologies compatible with SRV record polling.
|
||||
SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
|
||||
|
||||
|
||||
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
|
||||
|
||||
|
||||
class TopologyDescription:
|
||||
def __init__(
|
||||
self,
|
||||
topology_type: int,
|
||||
server_descriptions: dict[_Address, ServerDescription],
|
||||
replica_set_name: Optional[str],
|
||||
max_set_version: Optional[int],
|
||||
max_election_id: Optional[ObjectId],
|
||||
topology_settings: Any,
|
||||
) -> None:
|
||||
"""Representation of a deployment of MongoDB servers.
|
||||
|
||||
:param topology_type: initial type
|
||||
:param server_descriptions: dict of (address, ServerDescription) for
|
||||
all seeds
|
||||
:param replica_set_name: replica set name or None
|
||||
:param max_set_version: greatest setVersion seen from a primary, or None
|
||||
:param max_election_id: greatest electionId seen from a primary, or None
|
||||
:param topology_settings: a TopologySettings
|
||||
"""
|
||||
self._topology_type = topology_type
|
||||
self._replica_set_name = replica_set_name
|
||||
self._server_descriptions = server_descriptions
|
||||
self._max_set_version = max_set_version
|
||||
self._max_election_id = max_election_id
|
||||
|
||||
# The heartbeat_frequency is used in staleness estimates.
|
||||
self._topology_settings = topology_settings
|
||||
|
||||
# Is PyMongo compatible with all servers' wire protocols?
|
||||
self._incompatible_err = None
|
||||
if self._topology_type != TOPOLOGY_TYPE.LoadBalanced:
|
||||
self._init_incompatible_err()
|
||||
|
||||
# Server Discovery And Monitoring Spec: Whenever a client updates the
|
||||
# TopologyDescription from an hello response, it MUST set
|
||||
# TopologyDescription.logicalSessionTimeoutMinutes to the smallest
|
||||
# logicalSessionTimeoutMinutes value among ServerDescriptions of all
|
||||
# data-bearing server types. If any have a null
|
||||
# logicalSessionTimeoutMinutes, then
|
||||
# TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null.
|
||||
readable_servers = self.readable_servers
|
||||
if not readable_servers:
|
||||
self._ls_timeout_minutes = None
|
||||
elif any(s.logical_session_timeout_minutes is None for s in readable_servers):
|
||||
self._ls_timeout_minutes = None
|
||||
else:
|
||||
self._ls_timeout_minutes = min( # type: ignore[type-var]
|
||||
s.logical_session_timeout_minutes for s in readable_servers
|
||||
)
|
||||
|
||||
def _init_incompatible_err(self) -> None:
|
||||
"""Internal compatibility check for non-load balanced topologies."""
|
||||
for s in self._server_descriptions.values():
|
||||
if not s.is_server_type_known:
|
||||
continue
|
||||
|
||||
# s.min/max_wire_version is the server's wire protocol.
|
||||
# MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports.
|
||||
server_too_new = (
|
||||
# Server too new.
|
||||
s.min_wire_version is not None
|
||||
and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION
|
||||
)
|
||||
|
||||
server_too_old = (
|
||||
# Server too old.
|
||||
s.max_wire_version is not None
|
||||
and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION
|
||||
)
|
||||
|
||||
if server_too_new:
|
||||
self._incompatible_err = (
|
||||
"Server at %s:%d requires wire version %d, but this " # type: ignore
|
||||
"version of PyMongo only supports up to %d."
|
||||
% (
|
||||
s.address[0],
|
||||
s.address[1] or 0,
|
||||
s.min_wire_version,
|
||||
common.MAX_SUPPORTED_WIRE_VERSION,
|
||||
)
|
||||
)
|
||||
|
||||
elif server_too_old:
|
||||
self._incompatible_err = (
|
||||
"Server at %s:%d reports wire version %d, but this " # type: ignore
|
||||
"version of PyMongo requires at least %d (MongoDB %s)."
|
||||
% (
|
||||
s.address[0],
|
||||
s.address[1] or 0,
|
||||
s.max_wire_version,
|
||||
common.MIN_SUPPORTED_WIRE_VERSION,
|
||||
common.MIN_SUPPORTED_SERVER_VERSION,
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
def check_compatible(self) -> None:
|
||||
"""Raise ConfigurationError if any server is incompatible.
|
||||
|
||||
A server is incompatible if its wire protocol version range does not
|
||||
overlap with PyMongo's.
|
||||
"""
|
||||
if self._incompatible_err:
|
||||
raise ConfigurationError(self._incompatible_err)
|
||||
|
||||
def has_server(self, address: _Address) -> bool:
|
||||
return address in self._server_descriptions
|
||||
|
||||
def reset_server(self, address: _Address) -> TopologyDescription:
|
||||
"""A copy of this description, with one server marked Unknown."""
|
||||
unknown_sd = self._server_descriptions[address].to_unknown()
|
||||
return updated_topology_description(self, unknown_sd)
|
||||
|
||||
def reset(self) -> TopologyDescription:
|
||||
"""A copy of this description, with all servers marked Unknown."""
|
||||
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
|
||||
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
|
||||
else:
|
||||
topology_type = self._topology_type
|
||||
|
||||
# The default ServerDescription's type is Unknown.
|
||||
sds = {address: ServerDescription(address) for address in self._server_descriptions}
|
||||
|
||||
return TopologyDescription(
|
||||
topology_type,
|
||||
sds,
|
||||
self._replica_set_name,
|
||||
self._max_set_version,
|
||||
self._max_election_id,
|
||||
self._topology_settings,
|
||||
)
|
||||
|
||||
def server_descriptions(self) -> dict[_Address, ServerDescription]:
|
||||
"""dict of (address,
|
||||
:class:`~pymongo.server_description.ServerDescription`).
|
||||
"""
|
||||
return self._server_descriptions.copy()
|
||||
|
||||
@property
|
||||
def topology_type(self) -> int:
|
||||
"""The type of this topology."""
|
||||
return self._topology_type
|
||||
|
||||
@property
|
||||
def topology_type_name(self) -> str:
|
||||
"""The topology type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return TOPOLOGY_TYPE._fields[self._topology_type]
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""The replica set name."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def max_set_version(self) -> Optional[int]:
|
||||
"""Greatest setVersion seen from a primary, or None."""
|
||||
return self._max_set_version
|
||||
|
||||
@property
|
||||
def max_election_id(self) -> Optional[ObjectId]:
|
||||
"""Greatest electionId seen from a primary, or None."""
|
||||
return self._max_election_id
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
"""Minimum logical session timeout, or None."""
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def known_servers(self) -> list[ServerDescription]:
|
||||
"""List of Servers of types besides Unknown."""
|
||||
return [s for s in self._server_descriptions.values() if s.is_server_type_known]
|
||||
|
||||
@property
|
||||
def has_known_servers(self) -> bool:
|
||||
"""Whether there are any Servers of types besides Unknown."""
|
||||
return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
|
||||
|
||||
@property
|
||||
def readable_servers(self) -> list[ServerDescription]:
|
||||
"""List of readable Servers."""
|
||||
return [s for s in self._server_descriptions.values() if s.is_readable]
|
||||
|
||||
@property
|
||||
def common_wire_version(self) -> Optional[int]:
|
||||
"""Minimum of all servers' max wire versions, or None."""
|
||||
servers = self.known_servers
|
||||
if servers:
|
||||
return min(s.max_wire_version for s in self.known_servers)
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
return self._topology_settings.heartbeat_frequency
|
||||
|
||||
@property
|
||||
def srv_max_hosts(self) -> int:
|
||||
return self._topology_settings._srv_max_hosts
|
||||
|
||||
def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
|
||||
if not selection:
|
||||
return []
|
||||
round_trip_times: list[float] = []
|
||||
for server in selection.server_descriptions:
|
||||
if server.round_trip_time is None:
|
||||
config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}"
|
||||
raise ConfigurationError(config_err_msg)
|
||||
round_trip_times.append(server.round_trip_time)
|
||||
# Round trip time in seconds.
|
||||
fastest = min(round_trip_times)
|
||||
threshold = self._topology_settings.local_threshold_ms / 1000.0
|
||||
return [
|
||||
s
|
||||
for s in selection.server_descriptions
|
||||
if (cast(float, s.round_trip_time) - fastest) <= threshold
|
||||
]
|
||||
|
||||
def apply_selector(
|
||||
self,
|
||||
selector: Any,
|
||||
address: Optional[_Address] = None,
|
||||
custom_selector: Optional[_ServerSelector] = None,
|
||||
) -> list[ServerDescription]:
|
||||
"""List of servers matching the provided selector(s).
|
||||
|
||||
:param selector: a callable that takes a Selection as input and returns
|
||||
a Selection as output. For example, an instance of a read
|
||||
preference from :mod:`~pymongo.read_preferences`.
|
||||
:param address: A server address to select.
|
||||
:param custom_selector: A callable that augments server
|
||||
selection rules. Accepts a list of
|
||||
:class:`~pymongo.server_description.ServerDescription` objects and
|
||||
return a list of server descriptions that should be considered
|
||||
suitable for the desired operation.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
if getattr(selector, "min_wire_version", 0):
|
||||
common_wv = self.common_wire_version
|
||||
if common_wv and common_wv < selector.min_wire_version:
|
||||
raise ConfigurationError(
|
||||
"%s requires min wire version %d, but topology's min"
|
||||
" wire version is %d" % (selector, selector.min_wire_version, common_wv)
|
||||
)
|
||||
|
||||
if isinstance(selector, _AggWritePref):
|
||||
selector.selection_hook(self)
|
||||
|
||||
if self.topology_type == TOPOLOGY_TYPE.Unknown:
|
||||
return []
|
||||
elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced):
|
||||
# Ignore selectors for standalone and load balancer mode.
|
||||
return self.known_servers
|
||||
if address:
|
||||
# Ignore selectors when explicit address is requested.
|
||||
description = self.server_descriptions().get(address)
|
||||
return [description] if description else []
|
||||
|
||||
selection = Selection.from_topology_description(self)
|
||||
# Ignore read preference for sharded clusters.
|
||||
if self.topology_type != TOPOLOGY_TYPE.Sharded:
|
||||
selection = selector(selection)
|
||||
|
||||
# Apply custom selector followed by localThresholdMS.
|
||||
if custom_selector is not None and selection:
|
||||
selection = selection.with_server_descriptions(
|
||||
custom_selector(selection.server_descriptions)
|
||||
)
|
||||
return self._apply_local_threshold(selection)
|
||||
|
||||
def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool:
    """Report whether any known server satisfies the given read preference.

    :param read_preference: an instance of a read preference from
        :mod:`~pymongo.read_preferences`. Defaults to
        :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`.

    .. note:: When connected directly to a single server this method
      always returns ``True``.

    .. versionadded:: 3.4
    """
    # Validate the argument before using it as a selector.
    common.validate_read_preference("read_preference", read_preference)
    matching = self.apply_selector(read_preference)
    return any(matching)
|
||||
|
||||
def has_writable_server(self) -> bool:
    """Report whether a writable server is currently available.

    Implemented as a readable-server check with the PRIMARY read preference.

    .. note:: When connected directly to a single server this method
      always returns ``True``.

    .. versionadded:: 3.4
    """
    primary_only = ReadPreference.PRIMARY
    return self.has_readable_server(primary_only)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Sort the servers by address.
|
||||
servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address)
|
||||
return "<{} id: {}, topology_type: {}, servers: {!r}>".format(
|
||||
self.__class__.__name__,
|
||||
self._topology_settings._topology_id,
|
||||
self.topology_type_name,
|
||||
servers,
|
||||
)
|
||||
|
||||
|
||||
# If topology type is Unknown and we receive a hello response, what should
|
||||
# the new topology type be?
|
||||
_SERVER_TYPE_TO_TOPOLOGY_TYPE = {
|
||||
SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded,
|
||||
SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary,
|
||||
SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
|
||||
SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
|
||||
SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
|
||||
# Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out.
|
||||
}
|
||||
|
||||
|
||||
def updated_topology_description(
    topology_description: TopologyDescription, server_description: ServerDescription
) -> TopologyDescription:
    """Return an updated copy of a TopologyDescription.

    :param topology_description: the current TopologyDescription
    :param server_description: a new ServerDescription that resulted from
        a hello call

    Called after attempting (successfully or not) to call hello on the
    server at server_description.address. Does not modify topology_description.

    :return: a new TopologyDescription built from the (possibly changed)
        topology type, server descriptions, replica set name, max set
        version and max election id, reusing the original topology settings.
    """
    address = server_description.address

    # These values will be updated, if necessary, to form the new
    # TopologyDescription.
    topology_type = topology_description.topology_type
    set_name = topology_description.replica_set_name
    max_set_version = topology_description.max_set_version
    max_election_id = topology_description.max_election_id
    server_type = server_description.server_type

    # Don't mutate the original dict of server descriptions; copy it.
    sds = topology_description.server_descriptions()

    # Replace this server's description with the new one.
    sds[address] = server_description

    if topology_type == TOPOLOGY_TYPE.Single:
        # Set server type to Unknown if replica set name does not match.
        if set_name is not None and set_name != server_description.replica_set_name:
            error = ConfigurationError(
                "client is configured to connect to a replica set named "
                "'{}' but this node belongs to a set named '{}'".format(
                    set_name, server_description.replica_set_name
                )
            )
            # Keep the entry, but as an Unknown description carrying the error.
            sds[address] = server_description.to_unknown(error=error)
        # Single type never changes.
        return TopologyDescription(
            TOPOLOGY_TYPE.Single,
            sds,
            set_name,
            max_set_version,
            max_election_id,
            topology_description._topology_settings,
        )

    if topology_type == TOPOLOGY_TYPE.Unknown:
        if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer):
            # Only a lone seed lets a standalone/load balancer fix the type.
            if len(topology_description._topology_settings.seeds) == 1:
                topology_type = TOPOLOGY_TYPE.Single
            else:
                # Remove standalone from Topology when given multiple seeds.
                sds.pop(address)
        elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost):
            # Derive the topology type from the kind of server that answered.
            topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type]

    # Deliberately ``if`` rather than ``elif``: the Unknown branch above may
    # have just assigned a new topology_type, which is then handled here.
    if topology_type == TOPOLOGY_TYPE.Sharded:
        # Anything other than a mongos or an Unknown server is dropped.
        if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown):
            sds.pop(address)

    elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary:
        if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
            sds.pop(address)

        elif server_type == SERVER_TYPE.RSPrimary:
            (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
                sds, set_name, server_description, max_set_version, max_election_id
            )

        elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
            topology_type, set_name = _update_rs_no_primary_from_member(
                sds, set_name, server_description
            )

    elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
        if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
            sds.pop(address)
            # The removed server may have been the primary; re-check.
            topology_type = _check_has_primary(sds)

        elif server_type == SERVER_TYPE.RSPrimary:
            (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
                sds, set_name, server_description, max_set_version, max_election_id
            )

        elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
            topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description)

        else:
            # Server type is Unknown or RSGhost: did we just lose the primary?
            topology_type = _check_has_primary(sds)

    # Return updated copy.
    return TopologyDescription(
        topology_type,
        sds,
        set_name,
        max_set_version,
        max_election_id,
        topology_description._topology_settings,
    )
|
||||
|
||||
|
||||
def _updated_topology_description_srv_polling(
    topology_description: TopologyDescription, seedlist: list[tuple[str, Any]]
) -> TopologyDescription:
    """Return a copy of a TopologyDescription updated from an SRV seed list.

    :param topology_description: the current TopologyDescription
    :param seedlist: the list of seed addresses to reconcile the known
        servers against

    :return: ``topology_description`` itself when the known addresses already
        equal the seed list, otherwise a new TopologyDescription whose server
        descriptions were reconciled with ``seedlist``.
    """
    assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
    # Create a copy of the server descriptions.
    sds = topology_description.server_descriptions()
    # Build the seed set once; it is needed for every membership test below.
    seeds = set(seedlist)

    # If seeds haven't changed, don't do anything.
    if set(sds) == seeds:
        return topology_description

    # Remove SDs corresponding to servers no longer part of the SRV record.
    for stale_address in set(sds) - seeds:
        sds.pop(stale_address)

    # By default every seed is a candidate for addition, in seed-list order.
    to_add = seedlist
    max_hosts = topology_description.srv_max_hosts
    if max_hosts != 0:
        # A host limit is set: add at most enough random new hosts to reach it.
        new_hosts = seeds - set(sds)
        n_to_add = max_hosts - len(sds)
        if n_to_add > 0:
            to_add = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
        else:
            to_add = []

    # Add SDs corresponding to servers recently added to the SRV record.
    for address in to_add:
        if address not in sds:
            sds[address] = ServerDescription(address)

    return TopologyDescription(
        topology_description.topology_type,
        sds,
        topology_description.replica_set_name,
        topology_description.max_set_version,
        topology_description.max_election_id,
        topology_description._topology_settings,
    )
|
||||
|
||||
|
||||
def _update_rs_from_primary(
    sds: MutableMapping[_Address, ServerDescription],
    replica_set_name: Optional[str],
    server_description: ServerDescription,
    max_set_version: Optional[int],
    max_election_id: Optional[ObjectId],
) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
    """Update topology description from a primary's hello response.

    Pass in a dict of ServerDescriptions, current replica set name, the
    ServerDescription we are processing, and the TopologyDescription's
    max_set_version and max_election_id if any.

    ``sds`` is mutated in place: entries may be removed, added, or replaced
    with Unknown descriptions.

    Returns (new topology type, new replica_set_name, new max_set_version,
    new max_election_id).
    """
    if replica_set_name is None:
        replica_set_name = server_description.replica_set_name

    elif replica_set_name != server_description.replica_set_name:
        # We found a primary but it doesn't have the replica_set_name
        # provided by the user.
        sds.pop(server_description.address)
        return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id

    if server_description.max_wire_version is None or server_description.max_wire_version < 17:
        # Below wire version 17 (or unknown): order by (set_version, election_id).
        new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
        max_election_tuple: tuple = (max_set_version, max_election_id)
        if None not in new_election_tuple:
            if None not in max_election_tuple and new_election_tuple < max_election_tuple:
                # Stale primary, set to type Unknown.
                sds[server_description.address] = server_description.to_unknown()
                return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
            max_election_id = server_description.election_id

        if server_description.set_version is not None and (
            max_set_version is None or server_description.set_version > max_set_version
        ):
            max_set_version = server_description.set_version
    else:
        # Wire version 17 and above: order by (election_id, set_version).
        new_election_tuple = server_description.election_id, server_description.set_version
        max_election_tuple = max_election_id, max_set_version
        # Substitute MinKey() for None before comparing the tuples.
        new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple)
        max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple)
        if new_election_safe < max_election_safe:
            # Stale primary, set to type Unknown.
            sds[server_description.address] = server_description.to_unknown()
            return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
        else:
            max_election_id = server_description.election_id
            max_set_version = server_description.set_version

    # We've heard from the primary. Is it the same primary as before?
    for server in sds.values():
        # Compare server types by value, as every other check in this module does.
        if (
            server.server_type == SERVER_TYPE.RSPrimary
            and server.address != server_description.address
        ):
            # Reset old primary's type to Unknown.
            sds[server.address] = server.to_unknown()

            # There can be only one prior primary.
            break

    # Discover new hosts from this primary's response.
    for new_address in server_description.all_hosts:
        if new_address not in sds:
            sds[new_address] = ServerDescription(new_address)

    # Remove hosts not in the response.
    for addr in set(sds) - server_description.all_hosts:
        sds.pop(addr)

    # If the host list differs from the seed list, we may not have a primary
    # after all.
    return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
|
||||
|
||||
|
||||
def _update_rs_with_primary_from_member(
    sds: MutableMapping[_Address, ServerDescription],
    replica_set_name: Optional[str],
    server_description: ServerDescription,
) -> int:
    """Handle a non-primary's response while a primary is known.

    Takes the mutable dict of ServerDescriptions, the current replica set
    name (must not be None) and the ServerDescription being processed. The
    member is dropped from ``sds`` when it reports a different set name, or
    when it reports a ``me`` address that differs from its own address.

    Returns the new topology type.
    """
    assert replica_set_name is not None

    member_address = server_description.address
    foreign_set = replica_set_name != server_description.replica_set_name
    if foreign_set or (server_description.me and member_address != server_description.me):
        sds.pop(member_address)

    # The dropped member may have been the primary, so re-check.
    return _check_has_primary(sds)
|
||||
|
||||
|
||||
def _update_rs_no_primary_from_member(
    sds: MutableMapping[_Address, ServerDescription],
    replica_set_name: Optional[str],
    server_description: ServerDescription,
) -> tuple[int, Optional[str]]:
    """Handle a non-primary's response while no primary is known.

    Takes the mutable dict of ServerDescriptions, the current replica set
    name and the ServerDescription being processed.

    Returns (new topology type, new replica_set_name); the topology type is
    always ReplicaSetNoPrimary.
    """
    no_primary = TOPOLOGY_TYPE.ReplicaSetNoPrimary
    reported_name = server_description.replica_set_name

    if replica_set_name is None:
        # No name configured yet: adopt the one this member reports.
        replica_set_name = reported_name
    elif replica_set_name != reported_name:
        # Member of some other set: drop it and stop here.
        sds.pop(server_description.address)
        return no_primary, replica_set_name

    # A non-primary's host list is additive only: never remove servers it
    # fails to mention.
    for host in server_description.all_hosts:
        if host not in sds:
            sds[host] = ServerDescription(host)

    own_name = server_description.me
    if own_name and server_description.address != own_name:
        sds.pop(server_description.address)

    return no_primary, replica_set_name
|
||||
|
||||
|
||||
def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int:
    """Decide whether a replica set topology still has a known primary.

    Takes a dict of ServerDescriptions.

    Returns ReplicaSetWithPrimary if any description is an RSPrimary,
    otherwise ReplicaSetNoPrimary.
    """
    primary_known = any(s.server_type == SERVER_TYPE.RSPrimary for s in sds.values())
    if primary_known:
        return TOPOLOGY_TYPE.ReplicaSetWithPrimary
    return TOPOLOGY_TYPE.ReplicaSetNoPrimary
|
||||
61
pymongo/asynchronous/typings.py
Normal file
61
pymongo/asynchronous/typings.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright 2022-Present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Type aliases used by PyMongo"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.collation import Collation
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Common Shared Types.
|
||||
_Address = Tuple[str, Optional[int]]
|
||||
_CollationIn = Union[Mapping[str, Any], "Collation"]
|
||||
_Pipeline = Sequence[Mapping[str, Any]]
|
||||
ClusterTime = Mapping[str, Any]
|
||||
|
||||
_T = TypeVar("_T")


def strip_optional(elem: Optional[_T]) -> _T:
    """Narrow an ``Optional[_T]`` value to ``_T``.

    Exists so that every element of an iterator can be cast from
    ``Optional[_T]`` to ``_T`` inside a list comprehension. The value is
    returned unchanged; an assertion fails if it is ``None``.
    """
    assert elem is not None
    return elem
|
||||
|
||||
|
||||
__all__ = [
|
||||
"_DocumentOut",
|
||||
"_DocumentType",
|
||||
"_DocumentTypeArg",
|
||||
"_Address",
|
||||
"_CollationIn",
|
||||
"_Pipeline",
|
||||
"strip_optional",
|
||||
]
|
||||
624
pymongo/asynchronous/uri_parser.py
Normal file
624
pymongo/asynchronous/uri_parser.py
Normal file
@ -0,0 +1,624 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sized,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.asynchronous.client_options import _parse_ssl_options
|
||||
from pymongo.asynchronous.common import (
|
||||
INTERNAL_URI_OPTION_NAME_MAP,
|
||||
SRV_SERVICE_NAME,
|
||||
URI_OPTIONS_DEPRECATION_MAP,
|
||||
_CaseInsensitiveDictionary,
|
||||
get_validated_options,
|
||||
)
|
||||
from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver
|
||||
from pymongo.asynchronous.typings import _Address
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
|
||||
_IS_SYNC = False
|
||||
SCHEME = "mongodb://"
|
||||
SCHEME_LEN = len(SCHEME)
|
||||
SRV_SCHEME = "mongodb+srv://"
|
||||
SRV_SCHEME_LEN = len(SRV_SCHEME)
|
||||
DEFAULT_PORT = 27017
|
||||
|
||||
|
||||
def _unquoted_percent(s: str) -> bool:
|
||||
"""Check for unescaped percent signs.
|
||||
|
||||
:param s: A string. `s` can have things like '%25', '%2525',
|
||||
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
|
||||
"""
|
||||
for i in range(len(s)):
|
||||
if s[i] == "%":
|
||||
sub = s[i : i + 3]
|
||||
# If unquoting yields the same string this means there was an
|
||||
# unquoted %.
|
||||
if unquote_plus(sub) == sub:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parse_userinfo(userinfo: str) -> tuple[str, str]:
    """Validate and decode the user information part of a MongoDB URI.

    Reserved gen-delimiter characters (":", "/", "?", "#", "[", "]", "@")
    must be escaped as per RFC 3986.

    :param userinfo: A string of the form <username>:<password>
    :return: A 2-tuple of the unescaped username and the unescaped password.
        The password is the empty string when `userinfo` has no ":".
    :raises InvalidURI: if `userinfo` contains "@", more than one ":", an
        unescaped percent sign, or an empty username.
    """
    needs_escaping = (
        "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo)
    )
    if needs_escaping:
        raise InvalidURI(
            "Username and password must be escaped according to "
            "RFC 3986, use urllib.parse.quote_plus"
        )

    raw_user, _sep, raw_password = userinfo.partition(":")
    # Only the username is mandatory; no password is expected with GSSAPI
    # authentication.
    if not raw_user:
        raise InvalidURI("The empty string is not valid username.")

    return unquote_plus(raw_user), unquote_plus(raw_password)
|
||||
|
||||
|
||||
def parse_ipv6_literal_host(
    entity: str, default_port: Optional[int]
) -> tuple[str, Optional[Union[str, int]]]:
    """Split a bracketed IPv6 literal into its address and port.

    :param entity: A string that represents an IPv6 literal enclosed
        in brackets (e.g. '[::1]' or '[::1]:27017').
    :param default_port: The port number to use when one wasn't
        specified in entity.
    :return: A 2-tuple of the IPv6 literal and the port. The port is the
        still-unparsed string taken from `entity` when one is present,
        otherwise `default_port`.
    :raises ValueError: if `entity` contains no closing bracket.
    """
    if "]" not in entity:
        raise ValueError(
            "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
        )
    bracketed, separator, port_text = entity.partition("]:")
    if not separator:
        # No explicit port: strip the first and last characters.
        return entity[1:-1], default_port
    return bracketed[1:], port_text
|
||||
|
||||
|
||||
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
    """Validate a host string and split it into host and port.

    :param entity: A host or host:port string where host could be a
        hostname or IP address.
    :param default_port: The port number to use when one wasn't
        specified in entity.
    :return: A 2-tuple of the host and the port, where the port is
        `default_port` if the string did not specify one. Entities ending
        in ".sock" are returned as given, paired with `default_port`.
    :raises ValueError: if an unbracketed entity contains more than one ":"
        or the port is not a decimal integer in 1..65535.
    """
    host = entity
    port: Optional[Union[str, int]] = default_port
    if entity[0] == "[":
        host, port = parse_ipv6_literal_host(entity, default_port)
    elif entity.endswith(".sock"):
        return entity, default_port
    elif ":" in entity:
        if entity.count(":") > 1:
            raise ValueError(
                "Reserved characters such as ':' must be "
                "escaped according RFC 2396. An IPv6 "
                "address literal must be enclosed in '[' "
                "and ']' according to RFC 2732."
            )
        host, port = entity.split(":", 1)

    if isinstance(port, str):
        # A port taken from the string is still text: validate, then convert.
        in_range = port.isdigit() and 0 < int(port) <= 65535
        if not in_range:
            raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}")
        port = int(port)

    # Normalize hostname to lowercase, since DNS is case-insensitive:
    # http://tools.ietf.org/html/rfc4343
    # This prevents useless rediscovery if "foo.com" is in the seed list but
    # "FOO.com" is in the hello response.
    return host.lower(), port
|
||||
|
||||
|
||||
# Options whose values are implicitly determined by tlsInsecure.
|
||||
_IMPLICIT_TLSINSECURE_OPTS = {
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
}
|
||||
|
||||
|
||||
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
    """Build the options dictionary for split_options.

    Every ``readpreferencetags`` value is appended, undecoded, to a list
    under that key. Other options are stored one value per key and
    percent-decoded, except ``authmechanismproperties`` which is stored as
    given. A repeated key emits a warning and the later value wins.

    :param opts: The options portion of a URI.
    :param delim: The separator between options, passed to ``str.split``.
    :raises ValueError: if an option does not split into exactly one key and
        one value around "=".
    """
    parsed = _CaseInsensitiveDictionary()
    for pair in opts.split(delim):
        name, raw_value = pair.split("=")
        lowered = name.lower()
        if lowered == "readpreferencetags":
            parsed.setdefault(name, []).append(raw_value)
            continue
        if name in parsed:
            warnings.warn(f"Duplicate URI option '{name}'.", stacklevel=2)
        parsed[name] = (
            raw_value if lowered == "authmechanismproperties" else unquote_plus(raw_value)
        )

    return parsed
|
||||
|
||||
|
||||
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
    """Reject conflicting TLS options in the options dictionary.

    When ``tlsallowinvalidcertificates`` is ``True`` this also sets
    ``tlsdisableocspendpointcheck`` to ``True`` in `options`.

    :param options: Instance of _CaseInsensitiveDictionary containing
        MongoDB URI options.
    :return: The same `options` object.
    :raises InvalidURI: if mutually exclusive options are both present.
    """
    # Options implied by tlsInsecure must not also be given explicitly.
    if options.get("tlsinsecure") is not None:
        for implied in _IMPLICIT_TLSINSECURE_OPTS:
            if implied in options:
                err_msg = "URI options %s and %s cannot be specified simultaneously."
                raise InvalidURI(
                    err_msg % (options.cased_key("tlsinsecure"), options.cased_key(implied))
                )

    # Handle co-occurrence of OCSP & tlsAllowInvalidCertificates options.
    allow_invalid_certs = options.get("tlsallowinvalidcertificates")
    if allow_invalid_certs is not None:
        if "tlsdisableocspendpointcheck" in options:
            err_msg = "URI options %s and %s cannot be specified simultaneously."
            raise InvalidURI(
                err_msg
                % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
            )
        if allow_invalid_certs is True:
            options["tlsdisableocspendpointcheck"] = True

    # Handle co-occurrence of CRL and OCSP-related options.
    if options.get("tlscrlfile") is not None:
        for relaxed in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
            if options.get(relaxed) is True:
                err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
                raise InvalidURI(err_msg % (relaxed,))

    if "ssl" in options and "tls" in options:

        def truth_value(val: Any) -> Any:
            # Map the strings "true"/"false" to booleans; pass anything else through.
            return val == "true" if val in ("true", "false") else val

        ssl_value = truth_value(options.get("ssl"))
        tls_value = truth_value(options.get("tls"))
        if ssl_value != tls_value:
            err_msg = "Can not specify conflicting values for URI options %s and %s."
            raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))

    return options
|
||||
|
||||
|
||||
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
    """Warn about deprecated options found in the options dictionary.

    A deprecated option that was renamed is removed from `options` when the
    renamed option is present as well; otherwise it is kept.

    :param options: Instance of _CaseInsensitiveDictionary containing
        MongoDB URI options.
    :return: The same `options` object.
    """
    # Iterate over a snapshot because entries may be removed along the way.
    for name in list(options):
        if name not in URI_OPTIONS_DEPRECATION_MAP:
            continue
        mode, detail = URI_OPTIONS_DEPRECATION_MAP[name]
        superseded = False
        if mode == "renamed":
            replacement = detail
            superseded = replacement in options
            if superseded:
                text = "Deprecated option '%s' ignored in favor of '%s'." % (
                    options.cased_key(name),
                    options.cased_key(replacement),
                )
            else:
                text = "Option '%s' is deprecated, use '%s' instead." % (
                    options.cased_key(name),
                    replacement,
                )
        elif mode == "removed":
            text = "Option '%s' is deprecated. %s." % (options.cased_key(name), detail)
        else:
            continue

        # Warn from this frame so stacklevel=2 keeps pointing at our caller.
        warnings.warn(text, DeprecationWarning, stacklevel=2)
        if superseded:
            options.pop(name)

    return options
|
||||
|
||||
|
||||
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
    """Rename options to their internally-used names, in place.

    :param options: Instance of _CaseInsensitiveDictionary containing
        MongoDB URI options.
    :return: The same `options` object.
    """
    # Expand tlsInsecure: each implied option receives the same value.
    insecure = options.get("tlsinsecure")
    if insecure is not None:
        for implied in _IMPLICIT_TLSINSECURE_OPTS:
            options[implied] = insecure

    # Iterate over a snapshot because keys are replaced along the way.
    for public_name in list(options):
        internal_name = INTERNAL_URI_OPTION_NAME_MAP.get(public_name, None)
        if internal_name is None:
            continue
        options[internal_name] = options.pop(public_name)

    return options
|
||||
|
||||
|
||||
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
    """Validate and normalize options passed in a MongoDB URI.

    Thin wrapper that delegates to ``get_validated_options``.

    :param opts: A dict of MongoDB URI options.
    :param warn: If ``True`` then warnings will be logged and
        invalid options will be ignored. Otherwise invalid options will
        cause errors.
    :return: A new dictionary of validated and normalized options.
    """
    validated = get_validated_options(opts, warn)
    return validated
|
||||
|
||||
|
||||
def split_options(
    opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
) -> MutableMapping[str, Any]:
    """Parse the options portion of a MongoDB URI into a dictionary.

    :param opts: A string representing MongoDB URI options.
    :param validate: If ``True`` (the default), validate and normalize all
        options.
    :param warn: Forwarded to ``validate_options`` when `validate` is true.
        Defaults to ``False``.
    :param normalize: If ``True`` (the default), renames all options to their
        internally-used names.
    :raises InvalidURI: if "&" and ";" are mixed as separators, the options
        are not key=value pairs, conflicting security options are present,
        or validation yields an empty ``authsource``.
    """
    uses_ampersand = "&" in opts
    uses_semicolon = ";" in opts
    try:
        if uses_ampersand and uses_semicolon:
            raise InvalidURI("Can not mix '&' and ';' for option separators.")
        if uses_ampersand:
            options = _parse_options(opts, "&")
        elif uses_semicolon:
            options = _parse_options(opts, ";")
        elif "=" in opts:
            options = _parse_options(opts, None)
        else:
            # Nothing that looks like key=value: report it like a parse failure.
            raise ValueError
    except ValueError:
        raise InvalidURI("MongoDB URI options are key=value pairs.") from None

    options = _handle_option_deprecations(_handle_security_options(options))

    if normalize:
        options = _normalize_options(options)

    if not validate:
        return options

    options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
    if options.get("authsource") == "":
        raise InvalidURI("the authSource database cannot be an empty string")
    return options
|
||||
|
||||
|
||||
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
    """Split a comma-separated host string into (host, port) tuples.

    :param hosts: A string of the form host1[:port],host2[:port],...
    :param default_port: The port number to use when one wasn't specified
        for a host.
    :return: A list of 2-tuples containing the host name (or IP) followed by
        port number, in the order the hosts appear in `hosts`.
    :raises ConfigurationError: if any comma-separated entry is empty.
    """
    nodes = []
    for entity in hosts.split(","):
        if not entity:
            raise ConfigurationError("Empty host (or extra comma in host list).")
        # Unix socket entities don't have ports
        fallback_port = None if entity.endswith(".sock") else default_port
        nodes.append(parse_host(entity, fallback_port))
    return nodes
|
||||
|
||||
|
||||
# Prohibited characters in database name. DB names also can't have ".", but for
|
||||
# backward-compat we allow "db.collection" in URI.
|
||||
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
|
||||
|
||||
_ALLOWED_TXT_OPTS = frozenset(
|
||||
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
|
||||
)
|
||||
|
||||
|
||||
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
||||
# Ensure directConnection was not True if there are multiple seeds.
|
||||
if len(nodes) > 1 and options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
|
||||
|
||||
if options.get("loadbalanced"):
|
||||
if len(nodes) > 1:
|
||||
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
|
||||
if options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
|
||||
if options.get("replicaset"):
|
||||
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
||||
|
||||
|
||||
def parse_uri(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    """Parse and validate a MongoDB URI.

    Returns a dict of the form::

        {
            'nodelist': <list of (host, port) tuples>,
            'username': <username> or None,
            'password': <password> or None,
            'database': <database name> or None,
            'collection': <collection name> or None,
            'options': <dict of MongoDB URI options>,
            'fqdn': <fqdn of the MongoDB+SRV URI> or None
        }

    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
    to build nodelist and options.

    :param uri: The MongoDB URI to parse.
    :param default_port: The port number to use when one wasn't specified
          for a host in the URI.
    :param validate: If ``True`` (the default), validate and
          normalize all options. Default: ``True``.
    :param warn: When validating, if ``True`` then will warn
          the user then ignore any invalid options or values. If ``False``,
          validation will error when options are unsupported or values are
          invalid. Default: ``False``.
    :param normalize: If ``True``, convert names of URI options
          to their internally-used names. Default: ``True``.
    :param connect_timeout: The maximum time in milliseconds to
          wait for a response from the DNS server.
    :param srv_service_name: A custom SRV service name. When ``None``, the
          ``srvServiceName`` URI option is used, falling back to the
          module default.
    :param srv_max_hosts: Passed through to the SRV resolver. When falsy, the
          ``srvMaxHosts`` URI option is used instead. Only allowed with
          "mongodb+srv://" URIs and cannot be combined with ``loadBalanced``
          or ``replicaSet``.

    .. versionchanged:: 4.6
       The delimiting slash (``/``) between hosts and connection options is now optional.
       For example, "mongodb://example.com?tls=true" is now a valid URI.

    .. versionchanged:: 4.0
       To better follow RFC 3986, unquoted percent signs ("%") are no longer
       supported.

    .. versionchanged:: 3.9
       Added the ``normalize`` parameter.

    .. versionchanged:: 3.6
       Added support for mongodb+srv:// URIs.

    .. versionchanged:: 3.5
       Return the original value of the ``readPreference`` MongoDB URI option
       instead of the validated read preference mode.

    .. versionchanged:: 3.1
       ``warn`` added so invalid options can be ignored.
    """
    # Strip the scheme and remember whether SRV resolution is required.
    if uri.startswith(SCHEME):
        is_srv = False
        scheme_free = uri[SCHEME_LEN:]
    elif uri.startswith(SRV_SCHEME):
        if not _have_dnspython():
            python_path = sys.executable or "python"
            raise ConfigurationError(
                'The "dnspython" module must be '
                "installed to use mongodb+srv:// URIs. "
                "To fix this error install pymongo again:\n "
                "%s -m pip install pymongo>=4.3" % (python_path)
            )
        is_srv = True
        scheme_free = uri[SRV_SCHEME_LEN:]
    else:
        raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")

    if not scheme_free:
        raise InvalidURI("Must provide at least one hostname or IP.")

    user = None
    passwd = None
    dbase = None
    collection = None
    options = _CaseInsensitiveDictionary()

    # Everything after the first "?" is the option string; before it, an
    # optional "/" separates the host list from the database part.
    host_plus_db_part, _, opts = scheme_free.partition("?")
    if "/" in host_plus_db_part:
        host_part, _, dbase = host_plus_db_part.partition("/")
    else:
        host_part = host_plus_db_part

    if dbase:
        dbase = unquote_plus(dbase)
        # "db.collection" is split at the first "." only.
        if "." in dbase:
            dbase, collection = dbase.split(".", 1)
        if _BAD_DB_CHARS.search(dbase):
            raise InvalidURI('Bad database name "%s"' % dbase)
    else:
        # Normalise an empty database part ("mongodb://host/") to None.
        dbase = None

    if opts:
        options.update(split_options(opts, validate, warn, normalize))
    # An explicit keyword argument wins over the URI option.
    if srv_service_name is None:
        srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
    if "@" in host_part:
        # rpartition: only the last "@" separates userinfo from the hosts.
        userinfo, _, hosts = host_part.rpartition("@")
        user, passwd = parse_userinfo(userinfo)
    else:
        hosts = host_part

    if "/" in hosts:
        raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)

    hosts = unquote_plus(hosts)
    fqdn = None
    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
    if is_srv:
        if options.get("directConnection"):
            raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
        # default_port=None so that an explicit port can be detected below.
        nodes = split_hosts(hosts, default_port=None)
        if len(nodes) != 1:
            raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
        fqdn, port = nodes[0]
        if port is not None:
            raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")

        # Use the connection timeout. connectTimeoutMS passed as a keyword
        # argument overrides the same option passed in the connection string.
        connect_timeout = connect_timeout or options.get("connectTimeoutMS")
        dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
        nodes = dns_resolver.get_hosts()
        dns_options = dns_resolver.get_options()
        if dns_options:
            parsed_dns_options = split_options(dns_options, validate, warn, normalize)
            if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
                raise ConfigurationError(
                    "Only authSource, replicaSet, and loadBalanced are supported from DNS"
                )
            # Options given in the URI take precedence over DNS-provided ones.
            for opt, val in parsed_dns_options.items():
                if opt not in options:
                    options[opt] = val
        if options.get("loadBalanced") and srv_max_hosts:
            raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
        if options.get("replicaSet") and srv_max_hosts:
            raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
        # SRV URIs default to TLS unless tls/ssl was given explicitly; the
        # value is left as a string when validation is disabled.
        if "tls" not in options and "ssl" not in options:
            options["tls"] = True if validate else "true"
    elif not is_srv and options.get("srvServiceName") is not None:
        raise ConfigurationError(
            "The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
        )
    elif not is_srv and srv_max_hosts:
        raise ConfigurationError(
            "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
        )
    else:
        nodes = split_hosts(hosts, default_port=default_port)

    _check_options(nodes, options)

    return {
        "nodelist": nodes,
        "username": user,
        "password": passwd,
        "database": dbase,
        "collection": collection,
        "options": options,
        "fqdn": fqdn,
    }
|
||||
|
||||
|
||||
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
||||
"""Parse KMS TLS connection options."""
|
||||
if not kms_tls_options:
|
||||
return {}
|
||||
if not isinstance(kms_tls_options, dict):
|
||||
raise TypeError("kms_tls_options must be a dict")
|
||||
contexts = {}
|
||||
for provider, options in kms_tls_options.items():
|
||||
if not isinstance(options, dict):
|
||||
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
|
||||
options.setdefault("tls", True)
|
||||
opts = _CaseInsensitiveDictionary(options)
|
||||
opts = _handle_security_options(opts)
|
||||
opts = _normalize_options(opts)
|
||||
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
||||
if ssl_context is None:
|
||||
raise ConfigurationError("TLS is required for KMS providers")
|
||||
if allow_invalid_hostnames:
|
||||
raise ConfigurationError("Insecure TLS options prohibited")
|
||||
|
||||
for n in [
|
||||
"tlsInsecure",
|
||||
"tlsAllowInvalidCertificates",
|
||||
"tlsAllowInvalidHostnames",
|
||||
"tlsDisableCertificateRevocationCheck",
|
||||
]:
|
||||
if n in opts:
|
||||
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
|
||||
contexts[provider] = ssl_context
|
||||
return contexts
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line helper: parse the URI given as the first argument and
    # pretty-print the result, or print the error if the URI is invalid.
    import pprint

    try:
        parsed_uri = parse_uri(sys.argv[1])
    except InvalidURI as exc:
        print(exc)  # noqa: T201
    else:
        pprint.pprint(parsed_uri)  # noqa: T203
    sys.exit(0)
|
||||
645
pymongo/auth.py
645
pymongo/auth.py
@ -1,4 +1,4 @@
|
||||
# Copyright 2013-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,645 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Authentication helpers."""
|
||||
"""Re-import of synchronous Auth API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import quote
|
||||
from pymongo.synchronous.auth import * # noqa: F403
|
||||
from pymongo.synchronous.auth import __doc__ as original_doc
|
||||
|
||||
from bson.binary import Binary
|
||||
from pymongo.auth_aws import _authenticate_aws
|
||||
from pymongo.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
_OIDCAzureCallback,
|
||||
_OIDCGCPCallback,
|
||||
_OIDCProperties,
|
||||
_OIDCTestCallback,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.pool import Connection
|
||||
|
||||
# Probe for an optional Kerberos binding. ``winkerberos`` is tried first,
# then ``kerberos``; HAVE_KERBEROS ends up False when neither can be imported.
HAVE_KERBEROS = True
# Set to True only when winkerberos is imported and its version is >= 0.5.
_USE_PRINCIPAL = False
try:
    import winkerberos as kerberos  # type:ignore[import]

    # Compare only the (major, minor) part of the version string.
    if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
        _USE_PRINCIPAL = True
except ImportError:
    try:
        import kerberos  # type:ignore[import]
    except ImportError:
        HAVE_KERBEROS = False
|
||||
|
||||
|
||||
MECHANISMS = frozenset(
    (
        "GSSAPI",
        "MONGODB-CR",
        "MONGODB-OIDC",
        "MONGODB-X509",
        "MONGODB-AWS",
        "PLAIN",
        "SCRAM-SHA-1",
        "SCRAM-SHA-256",
        "DEFAULT",
    )
)
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
class _Cache:
|
||||
__slots__ = ("data",)
|
||||
|
||||
_hash_val = hash("_Cache")
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = None
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Two instances must always compare equal.
|
||||
if isinstance(other, _Cache):
|
||||
return True
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if isinstance(other, _Cache):
|
||||
return False
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self._hash_val
|
||||
|
||||
|
||||
MongoCredential = namedtuple(
|
||||
"MongoCredential",
|
||||
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
|
||||
)
|
||||
"""A hashable namedtuple of values used for authentication."""
|
||||
|
||||
|
||||
GSSAPIProperties = namedtuple(
|
||||
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
|
||||
)
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(
    mech: str,
    source: Optional[str],
    user: str,
    passwd: str,
    extra: Mapping[str, Any],
    database: Optional[str],
) -> MongoCredential:
    """Build and return a mechanism specific credentials tuple.

    :param mech: The authentication mechanism name.
    :param source: The explicitly requested authentication source, or None.
    :param user: The username; may be None for MONGODB-X509, MONGODB-AWS and
        MONGODB-OIDC.
    :param passwd: The password, or None.
    :param extra: Additional options; only ``authmechanismproperties`` is read.
    :param database: Fallback authentication source for PLAIN and for the
        mechanisms handled by the final branch.
    :raises ConfigurationError: If the values are not valid for ``mech``.
    :raises ValueError: If ``source`` is neither None nor ``$external`` for
        GSSAPI or MONGODB-X509.
    """
    # Only these three mechanisms may be used without a username.
    if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
        raise ConfigurationError(f"{mech} requires a username.")
    if mech == "GSSAPI":
        if source is not None and source != "$external":
            raise ValueError("authentication source must be $external or None for GSSAPI")
        properties = extra.get("authmechanismproperties", {})
        service_name = properties.get("SERVICE_NAME", "mongodb")
        canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
        service_realm = properties.get("SERVICE_REALM")
        props = GSSAPIProperties(
            service_name=service_name,
            canonicalize_host_name=canonicalize,
            service_realm=service_realm,
        )
        # Source is always $external.
        return MongoCredential(mech, "$external", user, passwd, props, None)
    elif mech == "MONGODB-X509":
        if passwd is not None:
            raise ConfigurationError("Passwords are not supported by MONGODB-X509")
        if source is not None and source != "$external":
            raise ValueError("authentication source must be $external or None for MONGODB-X509")
        # Source is always $external, user can be None.
        return MongoCredential(mech, "$external", user, None, None, None)
    elif mech == "MONGODB-AWS":
        if user is not None and passwd is None:
            raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
        if source is not None and source != "$external":
            raise ConfigurationError(
                "authentication source must be $external or None for MONGODB-AWS"
            )

        properties = extra.get("authmechanismproperties", {})
        aws_session_token = properties.get("AWS_SESSION_TOKEN")
        aws_props = _AWSProperties(aws_session_token=aws_session_token)
        # user can be None for temporary link-local EC2 credentials.
        return MongoCredential(mech, "$external", user, passwd, aws_props, None)
    elif mech == "MONGODB-OIDC":
        properties = extra.get("authmechanismproperties", {})
        callback = properties.get("OIDC_CALLBACK")
        human_callback = properties.get("OIDC_HUMAN_CALLBACK")
        environ = properties.get("ENVIRONMENT")
        token_resource = properties.get("TOKEN_RESOURCE", "")
        default_allowed = [
            "*.mongodb.net",
            "*.mongodb-dev.net",
            "*.mongodb-qa.net",
            "*.mongodbgov.net",
            "localhost",
            "127.0.0.1",
            "::1",
        ]
        allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
        # Generic message, reused below when a callback is combined with an
        # environment and when neither is given.
        msg = (
            "authentication with MONGODB-OIDC requires providing either a callback or a environment"
        )
        if passwd is not None:
            msg = "password is not supported by MONGODB-OIDC"
            raise ConfigurationError(msg)
        if callback or human_callback:
            # A user-supplied callback and ENVIRONMENT are mutually exclusive.
            if environ is not None:
                raise ConfigurationError(msg)
            if callback and human_callback:
                msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
                raise ConfigurationError(msg)
        elif environ is not None:
            # No user callback: pick the built-in callback for the environment.
            if environ == "test":
                if user is not None:
                    msg = "test environment for MONGODB-OIDC does not support username"
                    raise ConfigurationError(msg)
                callback = _OIDCTestCallback()
            elif environ == "azure":
                passwd = None
                if not token_resource:
                    raise ConfigurationError(
                        "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
                    )
                callback = _OIDCAzureCallback(token_resource)
            elif environ == "gcp":
                passwd = None
                if not token_resource:
                    raise ConfigurationError(
                        "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
                    )
                callback = _OIDCGCPCallback(token_resource)
            else:
                raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
        else:
            raise ConfigurationError(msg)

        oidc_props = _OIDCProperties(
            callback=callback,
            human_callback=human_callback,
            environment=environ,
            allowed_hosts=allowed_hosts,
            token_resource=token_resource,
            username=user,
        )
        return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())

    elif mech == "PLAIN":
        # PLAIN falls back to $external when no source or database is given.
        source_database = source or database or "$external"
        return MongoCredential(mech, source_database, user, passwd, None, None)
    else:
        # Every remaining mechanism falls back to "admin" and needs a password.
        source_database = source or database or "admin"
        if passwd is None:
            raise ConfigurationError("A password is required.")
        return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
|
||||
"""XOR two byte strings together."""
|
||||
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
|
||||
"""Split a scram response into key, value pairs."""
|
||||
return dict(
|
||||
typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
|
||||
for item in response.split(b",")
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_scram_start(
    credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
    """Build the first client message of a SCRAM conversation.

    :return: ``(nonce, first_bare, cmd)`` where ``nonce`` is a fresh base64
        encoded random value, ``first_bare`` is the ``n=<user>,r=<nonce>``
        message and ``cmd`` is the ``saslStart`` command carrying it.
    """
    # "=" must be escaped before ",": the escape sequence for "," contains "=".
    escaped_user = (
        credentials.username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
    )
    client_nonce = standard_b64encode(os.urandom(32))
    client_first_bare = b"".join((b"n=", escaped_user, b",r=", client_nonce))

    sasl_start = {
        "saslStart": 1,
        "mechanism": mechanism,
        "payload": Binary(b"n,," + client_first_bare),
        "autoAuthorize": 1,
        "options": {"skipEmptyExchange": True},
    }
    return client_nonce, client_first_bare, sasl_start
|
||||
|
||||
|
||||
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
    """Authenticate using SCRAM.

    :param credentials: Supplies username, password, source and the cache
        used to keep derived keys between calls.
    :param conn: The connection the SASL commands are sent on.
    :param mechanism: ``"SCRAM-SHA-256"`` selects SHA-256 with a
        SASLprep-processed password; any other value selects SHA-1 with the
        MD5 password digest.
    :raises OperationFailure: If the server's iteration count, nonce or
        signature is invalid, or the conversation does not complete.
    """
    username = credentials.username
    if mechanism == "SCRAM-SHA-256":
        digest = "sha256"
        digestmod = hashlib.sha256
        data = saslprep(credentials.password).encode("utf-8")
    else:
        digest = "sha1"
        digestmod = hashlib.sha1
        data = _password_digest(username, credentials.password).encode("utf-8")
    source = credentials.source
    cache = credentials.cache

    # Make local
    _hmac = hmac.HMAC

    # Reuse the server's reply to the speculative saslStart when there is
    # one; otherwise send saslStart now.
    ctx = conn.auth_ctx
    if ctx and ctx.speculate_succeeded():
        assert isinstance(ctx, _ScramContext)
        assert ctx.scram_data is not None
        nonce, first_bare = ctx.scram_data
        res = ctx.speculative_authenticate
    else:
        nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
        res = conn.command(source, cmd)

    assert res is not None
    server_first = res["payload"]
    parsed = _parse_scram_response(server_first)
    iterations = int(parsed[b"i"])
    if iterations < 4096:
        raise OperationFailure("Server returned an invalid iteration count.")
    salt = parsed[b"s"]
    rnonce = parsed[b"r"]
    # The server nonce must extend the nonce this client generated.
    if not rnonce.startswith(nonce):
        raise OperationFailure("Server returned an invalid nonce.")

    without_proof = b"c=biws,r=" + rnonce
    if cache.data:
        client_key, server_key, csalt, citerations = cache.data
    else:
        client_key, server_key, csalt, citerations = None, None, None, None

    # Salt and / or iterations could change for a number of different
    # reasons. Either changing invalidates the cache.
    if not client_key or salt != csalt or iterations != citerations:
        salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
        client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
        server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
        cache.data = (client_key, server_key, salt, iterations)
    stored_key = digestmod(client_key).digest()
    auth_msg = b",".join((first_bare, server_first, without_proof))
    client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
    client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
    client_final = b",".join((without_proof, client_proof))

    # Signature the server is expected to send back in its "v" attribute.
    server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())

    cmd = {
        "saslContinue": 1,
        "conversationId": res["conversationId"],
        "payload": Binary(client_final),
    }
    res = conn.command(source, cmd)

    parsed = _parse_scram_response(res["payload"])
    # compare_digest is used for the signature comparison rather than ``==``.
    if not hmac.compare_digest(parsed[b"v"], server_sig):
        raise OperationFailure("Server returned an invalid signature.")

    # A third empty challenge may be required if the server does not support
    # skipEmptyExchange: SERVER-44857.
    if not res["done"]:
        cmd = {
            "saslContinue": 1,
            "conversationId": res["conversationId"],
            "payload": Binary(b""),
        }
        res = conn.command(source, cmd)
        if not res["done"]:
            raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
|
||||
def _password_digest(username: str, password: str) -> str:
|
||||
"""Get a password digest to use for authentication."""
|
||||
if not isinstance(password, str):
|
||||
raise TypeError("password must be an instance of str")
|
||||
if len(password) == 0:
|
||||
raise ValueError("password can't be empty")
|
||||
if not isinstance(username, str):
|
||||
raise TypeError("username must be an instance of str")
|
||||
|
||||
md5hash = hashlib.md5() # noqa: S324
|
||||
data = f"{username}:mongo:{password}"
|
||||
md5hash.update(data.encode("utf-8"))
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _auth_key(nonce: str, username: str, password: str) -> str:
    """Return the hex MD5 digest of nonce + username + password digest."""
    pwd_digest = _password_digest(username, password)
    material = f"{nonce}{username}{pwd_digest}".encode("utf-8")
    hasher = hashlib.md5()  # noqa: S324
    hasher.update(material)
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def _canonicalize_hostname(hostname: str) -> str:
    """Canonicalize hostname following MIT-krb5 behavior.

    Resolves ``hostname`` with ``AI_CANONNAME``, then looks up the name of
    the first returned address. If that lookup raises ``socket.gaierror``,
    the canonical name from the first step is returned instead. The result
    is lowercased.
    """
    # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
    first_result = socket.getaddrinfo(
        hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
    )[0]
    _family, _socktype, _proto, canonical_name, address = first_result

    try:
        resolved = socket.getnameinfo(address, socket.NI_NAMEREQD)
    except socket.gaierror:
        return canonical_name.lower()

    return resolved[0].lower()
|
||||
|
||||
|
||||
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using GSSAPI.

    :param credentials: Supplies the username, optional password and the
        GSSAPI mechanism properties.
    :param conn: The connection to authenticate; its address provides the
        host part of the service name.
    :raises ConfigurationError: If no Kerberos module is available.
    :raises OperationFailure: If any Kerberos step fails. ``KrbError`` raised
        by the Kerberos module is converted to this type.
    """
    if not HAVE_KERBEROS:
        raise ConfigurationError(
            'The "kerberos" module must be installed to use GSSAPI authentication.'
        )

    try:
        username = credentials.username
        password = credentials.password
        props = credentials.mechanism_properties
        # Starting here and continuing through the while loop below - establish
        # the security context. See RFC 4752, Section 3.1, first paragraph.
        host = conn.address[0]
        if props.canonicalize_host_name:
            host = _canonicalize_hostname(host)
        # Service name is "<service>@<host>", with "@<realm>" appended when set.
        service = props.service_name + "@" + host
        if props.service_realm is not None:
            service = service + "@" + props.service_realm

        if password is not None:
            if _USE_PRINCIPAL:
                # Note that, though we use unquote_plus for unquoting URI
                # options, we use quote here. Microsoft's UrlUnescape (used
                # by WinKerberos) doesn't support +.
                principal = ":".join((quote(username), quote(password)))
                result, ctx = kerberos.authGSSClientInit(
                    service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
                )
            else:
                # Split "user@domain" at the first "@"; no "@" means no domain.
                if "@" in username:
                    user, domain = username.split("@", 1)
                else:
                    user, domain = username, None
                result, ctx = kerberos.authGSSClientInit(
                    service,
                    gssflags=kerberos.GSS_C_MUTUAL_FLAG,
                    user=user,
                    domain=domain,
                    password=password,
                )
        else:
            result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)

        if result != kerberos.AUTH_GSS_COMPLETE:
            raise OperationFailure("Kerberos context failed to initialize.")

        # From here on ctx exists and is released in the finally clause.
        try:
            # pykerberos uses a weird mix of exceptions and return values
            # to indicate errors.
            # 0 == continue, 1 == complete, -1 == error
            # Only authGSSClientStep can return 0.
            if kerberos.authGSSClientStep(ctx, "") != 0:
                raise OperationFailure("Unknown kerberos failure in step function.")

            # Start a SASL conversation with mongod/s
            # Note: pykerberos deals with base64 encoded byte strings.
            # Since mongo accepts base64 strings as the payload we don't
            # have to use bson.binary.Binary.
            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslStart": 1,
                "mechanism": "GSSAPI",
                "payload": payload,
                "autoAuthorize": 1,
            }
            response = conn.command("$external", cmd)

            # Limit how many times we loop to catch protocol / library issues
            for _ in range(10):
                result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
                if result == -1:
                    raise OperationFailure("Unknown kerberos failure in step function.")

                payload = kerberos.authGSSClientResponse(ctx) or ""

                cmd = {
                    "saslContinue": 1,
                    "conversationId": response["conversationId"],
                    "payload": payload,
                }
                response = conn.command("$external", cmd)

                if result == kerberos.AUTH_GSS_COMPLETE:
                    break
            else:
                # for/else: reached only when the loop never hit ``break``.
                raise OperationFailure("Kerberos authentication failed to complete.")

            # Once the security context is established actually authenticate.
            # See RFC 4752, Section 3.1, last two paragraphs.
            if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")

            if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")

            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslContinue": 1,
                "conversationId": response["conversationId"],
                "payload": payload,
            }
            conn.command("$external", cmd)

        finally:
            kerberos.authGSSClientClean(ctx)

    except kerberos.KrbError as exc:
        # ``from None`` drops the KrbError from the exception chain.
        raise OperationFailure(str(exc)) from None
|
||||
|
||||
|
||||
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using SASL PLAIN (RFC 4616).

    Sends a single ``saslStart`` command whose payload is
    ``NUL + username + NUL + password`` to the credentials' source database.
    """
    message = f"\x00{credentials.username}\x00{credentials.password}"
    sasl_start = {
        "saslStart": 1,
        "mechanism": "PLAIN",
        "payload": Binary(message.encode()),
        "autoAuthorize": 1,
    }
    conn.command(credentials.source, sasl_start)
|
||||
|
||||
|
||||
def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-X509.

    Sends the authenticate command to ``$external`` unless speculative
    authentication already succeeded on this connection.
    """
    spec_ctx = conn.auth_ctx
    already_authenticated = spec_ctx and spec_ctx.speculate_succeeded()
    if not already_authenticated:
        # Without a successful speculative step the command has to be sent
        # explicitly; it is the same command the speculative step would use.
        command = _X509Context(credentials, conn.address).speculate_command()
        conn.command("$external", command)
|
||||
|
||||
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-CR.

    Requests a nonce from the server, derives the auth key from it, and
    sends both back in an ``authenticate`` command.
    """
    source, username, password = (
        credentials.source,
        credentials.username,
        credentials.password,
    )
    # Get a nonce
    nonce = conn.command(source, {"getnonce": 1})["nonce"]
    auth_key = _auth_key(nonce, username, password)

    # Actually authenticate
    conn.command(
        source,
        {"authenticate": 1, "user": username, "nonce": nonce, "key": auth_key},
    )
|
||||
|
||||
|
||||
def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate with the best SCRAM variant available for this user.

    SCRAM-SHA-256 is used when the connection's wire version is at least 7
    and the mechanism list for the user contains it; SCRAM-SHA-1 otherwise.
    The list comes from the connection's negotiated mechanisms or, when
    those are empty, from a hello command with ``saslSupportedMechs``.
    """
    mechanism = "SCRAM-SHA-1"
    if conn.max_wire_version >= 7:
        supported = conn.negotiated_mechs
        if not supported:
            # Nothing was negotiated up front; ask the server explicitly.
            source = credentials.source
            hello = conn.hello_cmd()
            hello["saslSupportedMechs"] = source + "." + credentials.username
            reply = conn.command(source, hello, publish_events=False)
            supported = reply.get("saslSupportedMechs", [])
        if "SCRAM-SHA-256" in supported:
            mechanism = "SCRAM-SHA-256"
    return _authenticate_scram(credentials, conn, mechanism)
|
||||
|
||||
|
||||
# Dispatch table from mechanism name to the function that performs the
# authentication. The SCRAM entries pre-bind ``mechanism`` with
# functools.partial so every entry can be called as
# ``func(credentials, conn)``.
_AUTH_MAP: Mapping[str, Callable[..., None]] = {
    "GSSAPI": _authenticate_gssapi,
    "MONGODB-CR": _authenticate_mongo_cr,
    "MONGODB-X509": _authenticate_x509,
    "MONGODB-AWS": _authenticate_aws,
    "MONGODB-OIDC": _authenticate_oidc,  # type:ignore[dict-item]
    "PLAIN": _authenticate_plain,
    "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
    "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
    "DEFAULT": _authenticate_default,
}
|
||||
|
||||
|
||||
class _AuthContext:
|
||||
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
|
||||
self.credentials = credentials
|
||||
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
|
||||
self.address = address
|
||||
|
||||
@staticmethod
|
||||
def from_credentials(
|
||||
creds: MongoCredential, address: tuple[str, int]
|
||||
) -> Optional[_AuthContext]:
|
||||
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
|
||||
if spec_cls:
|
||||
return cast(_AuthContext, spec_cls(creds, address))
|
||||
return None
|
||||
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
|
||||
self.speculative_authenticate = hello.speculative_authenticate
|
||||
|
||||
def speculate_succeeded(self) -> bool:
|
||||
return bool(self.speculative_authenticate)
|
||||
|
||||
|
||||
class _ScramContext(_AuthContext):
    """Speculative authentication context for the SCRAM mechanisms."""

    def __init__(
        self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
    ) -> None:
        super().__init__(credentials, address)
        self.mechanism = mechanism
        # (nonce, first_bare) of the speculative saslStart, set by
        # speculate_command.
        self.scram_data: Optional[tuple[bytes, bytes]] = None

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Return the speculative ``saslStart`` command and keep its nonce data."""
        client_nonce, client_first_bare, command = _authenticate_scram_start(
            self.credentials, self.mechanism
        )
        # The 'db' field is included only on the speculative command.
        command["db"] = self.credentials.source
        # Save for later use.
        self.scram_data = (client_nonce, client_first_bare)
        return command
|
||||
|
||||
|
||||
class _X509Context(_AuthContext):
    """Speculative authentication context for MONGODB-X509."""

    def speculate_command(self) -> MutableMapping[str, Any]:
        """Return the X.509 ``authenticate`` command.

        The ``user`` field is only included when the credentials carry a
        username.
        """
        command = {"authenticate": 1, "mechanism": "MONGODB-X509"}
        username = self.credentials.username
        if username is not None:
            command["user"] = username
        return command
|
||||
|
||||
|
||||
class _OIDCContext(_AuthContext):
    """Speculative authentication context for MONGODB-OIDC."""

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Return the authenticator's speculative command, or None if it has none."""
        command = _get_authenticator(self.credentials, self.address).get_spec_auth_cmd()
        if command is not None:
            # Tag the speculative command with the authentication database.
            command["db"] = self.credentials.source
        return command
|
||||
|
||||
|
||||
# Maps a mechanism name to a callable taking ``(credentials, address)`` and
# returning the matching speculative-auth context. "DEFAULT" uses the
# SCRAM-SHA-256 context.
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = dict(
    [
        ("MONGODB-X509", _X509Context),
        ("SCRAM-SHA-1", functools.partial(_ScramContext, mechanism="SCRAM-SHA-1")),
        ("SCRAM-SHA-256", functools.partial(_ScramContext, mechanism="SCRAM-SHA-256")),
        ("MONGODB-OIDC", _OIDCContext),
        ("DEFAULT", functools.partial(_ScramContext, mechanism="SCRAM-SHA-256")),
    ]
)
|
||||
|
||||
|
||||
def authenticate(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
    """Authenticate connection.

    The handler is looked up in ``_AUTH_MAP`` for every mechanism (so an
    unknown mechanism raises ``KeyError``); MONGODB-OIDC is then routed to
    ``_authenticate_oidc`` because it also needs the ``reauthenticate`` flag.
    """
    mech = credentials.mechanism
    handler = _AUTH_MAP[mech]
    if mech != "MONGODB-OIDC":
        handler(credentials, conn)
        return
    _authenticate_oidc(credentials, conn, reauthenticate)
|
||||
__doc__ = original_doc
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,354 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
"""Re-import of synchronous AuthOIDC API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
|
||||
from urllib.parse import quote
|
||||
from pymongo.synchronous.auth_oidc import * # noqa: F403
|
||||
from pymongo.synchronous.auth_oidc import __doc__ as original_doc
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo._azure_helpers import _get_azure_response
|
||||
from pymongo._csot import remaining
|
||||
from pymongo._gcp_helpers import _get_gcp_response
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers import _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
@dataclass
class OIDCIdPInfo:
    """Identity-provider information: an issuer plus optional client id and scopes."""

    issuer: str
    clientId: Optional[str] = None
    requestScopes: Optional[list[str]] = None
|
||||
|
||||
|
||||
@dataclass
class OIDCCallbackContext:
    """Values handed to an OIDC callback's ``fetch`` method."""

    timeout_seconds: float
    username: str
    version: int
    refresh_token: Optional[str] = None
    idp_info: Optional[OIDCIdPInfo] = None
|
||||
|
||||
|
||||
@dataclass
class OIDCCallbackResult:
    """Tokens produced by an OIDC callback."""

    access_token: str
    expires_in_seconds: Optional[float] = None
    refresh_token: Optional[str] = None
|
||||
|
||||
|
||||
class OIDCCallback(abc.ABC):
|
||||
"""A base class for defining OIDC callbacks."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
|
||||
"""Fetch OIDC tokens for the given callback context and return them as an OIDCCallbackResult."""
|
||||
|
||||
|
||||
@dataclass
class _OIDCProperties:
    """Mechanism properties for MONGODB-OIDC authentication."""

    # Machine callback, used when no human callback is configured.
    callback: Optional[OIDCCallback] = field(default=None)
    # Human callback; takes precedence over ``callback`` when both are set.
    human_callback: Optional[OIDCCallback] = field(default=None)
    # When set, the allowed-hosts check in _get_authenticator is skipped.
    environment: Optional[str] = field(default=None)
    # Host patterns (exact names or "*."-prefixed suffixes) that may be
    # connected to when no environment is configured.
    allowed_hosts: list[str] = field(default_factory=list)
    token_resource: Optional[str] = field(default=None)
    # Passed to callbacks through OIDCCallbackContext.username.
    username: str = ""
|
||||
|
||||
# Tuning constants for the OIDC callback machinery.
TOKEN_BUFFER_MINUTES = 5
# Timeout passed to human callbacks (five minutes, in seconds).
HUMAN_CALLBACK_TIMEOUT_SECONDS = 60 * 5
# Version number reported to callbacks via OIDCCallbackContext.version.
CALLBACK_VERSION = 1
# Fallback timeout for machine callbacks when no other timeout applies.
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
# Minimum spacing enforced between two callback invocations.
TIME_BETWEEN_CALLS_SECONDS = 0.1
|
||||
|
||||
|
||||
def _get_authenticator(
    credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
    """Return the authenticator cached on ``credentials``, creating it if needed.

    When the mechanism properties carry no ``environment``, the host in
    ``address`` must match an entry of ``allowed_hosts`` (exactly, or through a
    ``*.`` wildcard suffix); otherwise :class:`ConfigurationError` is raised.
    """
    cache = credentials.cache
    if cache.data:
        return cache.data

    # Pull the values needed to build a new authenticator.
    principal_name = credentials.username
    properties = credentials.mechanism_properties

    # Only enforce the host allow-list when no environment is configured.
    if not properties.environment:
        allowed_hosts = properties.allowed_hosts
        host = address[0]
        # Every pattern is evaluated, as in a plain loop without ``break``.
        matches = [
            pattern == host or (pattern.startswith("*.") and host.endswith(pattern[1:]))
            for pattern in allowed_hosts
        ]
        if not any(matches):
            raise ConfigurationError(
                f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
            )

    # Create the authenticator and remember it on the credentials.
    cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
    return cache.data
|
||||
|
||||
|
||||
class _OIDCTestCallback(OIDCCallback):
    """Callback that reads the access token from the file named by ``OIDC_TOKEN_FILE``."""

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the stripped contents of the token file as the access token."""
        path = os.getenv("OIDC_TOKEN_FILE")
        if path:
            with open(path) as token_fh:
                token = token_fh.read().strip()
                return OIDCCallbackResult(access_token=token)
        raise RuntimeError(
            'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
        )
|
||||
|
||||
|
||||
class _OIDCAzureCallback(OIDCCallback):
    """Callback that obtains tokens through ``_get_azure_response``."""

    def __init__(self, token_resource: str) -> None:
        # The resource is URL-quoted once, up front.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token and return its ``access_token`` and ``expires_in`` values."""
        response = _get_azure_response(
            self.token_resource, context.username, context.timeout_seconds
        )
        token = response["access_token"]
        lifetime = response["expires_in"]
        return OIDCCallbackResult(access_token=token, expires_in_seconds=lifetime)
|
||||
|
||||
|
||||
class _OIDCGCPCallback(OIDCCallback):
    """Callback that obtains tokens through ``_get_gcp_response``."""

    def __init__(self, token_resource: str) -> None:
        # The resource is URL-quoted once, up front.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token and return its ``access_token`` value."""
        response = _get_gcp_response(self.token_resource, context.timeout_seconds)
        token = response["access_token"]
        return OIDCCallbackResult(access_token=token)
|
||||
|
||||
|
||||
@dataclass
class _OIDCAuthenticator:
    """Token cache and SASL conversation driver for MONGODB-OIDC.

    One instance is stored on the credentials cache (see
    ``_get_authenticator``). ``token_gen_id`` is bumped every time a callback
    yields a new access token, and each connection records the generation it
    authenticated with so stale invalidation requests can be ignored.
    """

    username: str
    properties: _OIDCProperties
    refresh_token: Optional[str] = field(default=None)
    access_token: Optional[str] = field(default=None)
    idp_info: Optional[OIDCIdPInfo] = field(default=None)
    token_gen_id: int = field(default=0)
    # Serializes callback invocations in _get_access_token.
    lock: threading.Lock = field(default_factory=threading.Lock)
    # time.time() of the most recent callback invocation.
    last_call_time: float = field(default=0)

    def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle a reauthenticate from the server."""
        # Invalidate the token for the connection.
        self._invalidate(conn)
        # Call the appropriate auth logic for the callback type.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle an initial authenticate request."""
        # First handle speculative auth.
        # If it succeeded, we are done.
        ctx = conn.auth_ctx
        if ctx and ctx.speculate_succeeded():
            resp = ctx.speculative_authenticate
            if resp and resp["done"]:
                conn.oidc_token_gen_id = self.token_gen_id
                return resp

        # If spec auth failed, call the appropriate auth logic for the callback type.
        # We cannot assume that the token is invalid, because a proxy may have been
        # involved that stripped the speculative auth information.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
        """Get the appropriate speculative auth command.

        Returns ``None`` when there is no cached access token to send.
        """
        if not self.access_token:
            return None
        return self._get_start_command({"jwt": self.access_token})

    def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
        """Authenticate using the machine callback flow."""
        # If there is a cached access token, try to authenticate with it. If
        # authentication fails with error code 18, invalidate the access token,
        # fetch a new access token, and try to authenticate again. If authentication
        # fails for any other reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    return self._authenticate_machine(conn)
                raise
        return self._sasl_start_jwt(conn)

    def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Authenticate using the human callback flow."""
        # If we have a cached access token, try a JwtStepRequest. If
        # authentication fails with error code 18, invalidate the access token,
        # and try to authenticate again. If authentication fails for any other
        # reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    return self._authenticate_human(conn)
                raise

        # If we have a cached refresh token, try a JwtStepRequest with that.
        # If authentication fails with error code 18, invalidate the access and
        # refresh tokens, and try to authenticate again. If authentication fails for
        # any other reason, raise the error to the user.
        if self.refresh_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    self.refresh_token = None
                    return self._authenticate_human(conn)
                raise

        # Start a new Two-Step SASL conversation.
        # Run a PrincipalStepRequest to get the IdpInfo.
        cmd = self._get_start_command(None)
        start_resp = self._run_command(conn, cmd)
        # Attempt to authenticate with a JwtStepRequest.
        return self._sasl_continue_jwt(conn, start_resp)

    def _get_access_token(self) -> Optional[str]:
        """Return the cached access token, invoking the callback if there is none.

        Returns ``None`` when a human callback is configured but no IdP info
        is known yet, or when no callback is configured and nothing is cached.

        :raises ValueError: if the callback does not return an
            ``OIDCCallbackResult``.
        """
        properties = self.properties

        is_human = properties.human_callback is not None
        if is_human and self.idp_info is None:
            return None

        # Initialise explicitly so that "no callback configured" is reported as
        # None below rather than surfacing as an UnboundLocalError. The human
        # callback takes precedence when both are set.
        cb: Optional[OIDCCallback] = None
        if properties.callback:
            cb = properties.callback
        if properties.human_callback:
            cb = properties.human_callback

        prev_token = self.access_token
        if prev_token:
            return prev_token

        if cb is None:
            return None

        with self.lock:
            # See if the token was changed while we were waiting for the
            # lock.
            new_token = self.access_token
            if new_token != prev_token:
                return new_token

            # Ensure that we are waiting a min time between callback invocations.
            delta = time.time() - self.last_call_time
            if delta < TIME_BETWEEN_CALLS_SECONDS:
                time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
            self.last_call_time = time.time()

            if is_human:
                timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
                assert self.idp_info is not None
            else:
                timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
            context = OIDCCallbackContext(
                timeout_seconds=timeout,
                version=CALLBACK_VERSION,
                refresh_token=self.refresh_token,
                idp_info=self.idp_info,
                username=self.properties.username,
            )
            resp = cb.fetch(context)
            if not isinstance(resp, OIDCCallbackResult):
                raise ValueError("Callback result must be of type OIDCCallbackResult")
            self.refresh_token = resp.refresh_token
            self.access_token = resp.access_token
            # A new token generation: connections holding an older id can no
            # longer invalidate this token.
            self.token_gen_id += 1

        return self.access_token

    def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Run ``cmd`` against ``$external``, invalidating the token on auth errors."""
        try:
            return conn.command("$external", cmd, no_reauth=True)  # type: ignore[call-arg]
        except OperationFailure as e:
            if self._is_auth_error(e):
                self._invalidate(conn)
            raise

    def _is_auth_error(self, err: Exception) -> bool:
        """Return True if ``err`` is an OperationFailure with the auth-failure code."""
        if not isinstance(err, OperationFailure):
            return False
        return err.code == _AUTHENTICATION_FAILURE_CODE

    def _invalidate(self, conn: Connection) -> None:
        """Drop the cached access token unless ``conn`` used an older generation."""
        # Ignore the invalidation if the connection's token gen id is less than
        # our current token gen id; a missing id is treated as generation 0.
        token_gen_id = conn.oidc_token_gen_id or 0
        if token_gen_id < self.token_gen_id:
            return
        self.access_token = None

    def _sasl_continue_jwt(
        self, conn: Connection, start_resp: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Finish a two-step conversation by sending a JWT via ``saslContinue``."""
        self.access_token = None
        self.refresh_token = None
        start_payload: dict = bson.decode(start_resp["payload"])
        if "issuer" in start_payload:
            self.idp_info = OIDCIdPInfo(**start_payload)
        access_token = self._get_access_token()
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_continue_command({"jwt": access_token}, start_resp)
        return self._run_command(conn, cmd)

    def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
        """Send a ``saslStart`` carrying the current (or freshly fetched) JWT."""
        access_token = self._get_access_token()
        # Record the generation used so a later failure invalidates the right token.
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_start_command({"jwt": access_token})
        return self._run_command(conn, cmd)

    def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
        """Build a ``saslStart`` command.

        With ``payload=None`` the payload is ``{"n": username}`` when a
        username is set and an empty document otherwise.
        """
        if payload is None:
            principal_name = self.username
            if principal_name:
                payload = {"n": principal_name}
            else:
                payload = {}
        bin_payload = Binary(bson.encode(payload))
        return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}

    def _get_continue_command(
        self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Build a ``saslContinue`` command for the conversation in ``start_resp``."""
        bin_payload = Binary(bson.encode(payload))
        return {
            "saslContinue": 1,
            "payload": bin_payload,
            "conversationId": start_resp["conversationId"],
        }
|
||||
|
||||
|
||||
def _authenticate_oidc(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
    """Authenticate using MONGODB-OIDC.

    Dispatches to the authenticator's ``reauthenticate`` or ``authenticate``
    method depending on the ``reauthenticate`` flag.
    """
    authenticator = _get_authenticator(credentials, conn.address)
    handler = authenticator.reauthenticate if reauthenticate else authenticator.authenticate
    return handler(conn)
|
||||
__doc__ = original_doc
|
||||
|
||||
@ -1,489 +1,21 @@
|
||||
# Copyright 2017 MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Watch changes on a collection, a database, or the entire cluster."""
|
||||
"""Re-import of synchronous ChangeStream API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
from pymongo.synchronous.change_stream import * # noqa: F403
|
||||
from pymongo.synchronous.change_stream import __doc__ as original_doc
|
||||
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common
|
||||
from pymongo.aggregation import (
|
||||
_AggregationCommand,
|
||||
_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand,
|
||||
)
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
CursorNotFound,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
# The change streams spec considers the following server errors from the
|
||||
# getMore command resumable on servers with max wire version < 9.
|
||||
_RESUMABLE_GETMORE_ERRORS = frozenset(
    {
        # Network and availability errors.
        6,  # HostUnreachable
        7,  # HostNotFound
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
        # Shutdown and primary state-change errors.
        91,  # ShutdownInProgress
        189,  # PrimarySteppedDown
        10107,  # NotWritablePrimary
        11600,  # InterruptedAtShutdown
        11602,  # InterruptedDueToReplStateChange
        13435,  # NotPrimaryNoSecondaryOk
        13436,  # NotPrimaryOrSecondary
        # Sharding metadata errors.
        63,  # StaleShardVersion
        150,  # StaleEpoch
        13388,  # StaleConfig
        # Miscellaneous.
        234,  # RetryChangeStream
        133,  # FailedToSatisfyReadPreference
    }
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
    """Return True if given a resumable change stream error."""
    # Connection failures and lost cursors are always resumable.
    if isinstance(exc, (ConnectionFailure, CursorNotFound)):
        return True
    if not isinstance(exc, OperationFailure):
        return False

    wire_version = exc._max_wire_version
    if wire_version is None:
        return False
    if wire_version >= 9:
        # Wire version 9+ is judged by the error label alone.
        return exc.has_error_label("ResumableChangeStreamError")
    # Older wire versions are judged by the error code.
    return exc.code in _RESUMABLE_GETMORE_ERRORS
|
||||
|
||||
|
||||
class ChangeStream(Generic[_DocumentType]):
|
||||
"""The internal abstract base class for change stream cursors.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
:meth:`pymongo.collection.Collection.watch`,
|
||||
:meth:`pymongo.database.Database.watch`, or
|
||||
:meth:`pymongo.mongo_client.MongoClient.watch` instead.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[
|
||||
MongoClient[_DocumentType], Database[_DocumentType], Collection[_DocumentType]
|
||||
],
|
||||
pipeline: Optional[_Pipeline],
|
||||
full_document: Optional[str],
|
||||
resume_after: Optional[Mapping[str, Any]],
|
||||
max_await_time_ms: Optional[int],
|
||||
batch_size: Optional[int],
|
||||
collation: Optional[_CollationIn],
|
||||
start_at_operation_time: Optional[Timestamp],
|
||||
session: Optional[ClientSession],
|
||||
start_after: Optional[Mapping[str, Any]],
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
show_expanded_events: Optional[bool] = None,
|
||||
) -> None:
|
||||
if pipeline is None:
|
||||
pipeline = []
|
||||
pipeline = common.validate_list("pipeline", pipeline)
|
||||
common.validate_string_or_none("full_document", full_document)
|
||||
validate_collation_or_none(collation)
|
||||
common.validate_non_negative_integer_or_none("batchSize", batch_size)
|
||||
|
||||
self._decode_custom = False
|
||||
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
|
||||
if target.codec_options.type_registry._decoder_map:
|
||||
self._decode_custom = True
|
||||
# Keep the type registry so that we support encoding custom types
|
||||
# in the pipeline.
|
||||
self._target = target.with_options( # type: ignore
|
||||
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
|
||||
)
|
||||
else:
|
||||
self._target = target
|
||||
|
||||
self._pipeline = copy.deepcopy(pipeline)
|
||||
self._full_document = full_document
|
||||
self._full_document_before_change = full_document_before_change
|
||||
self._uses_start_after = start_after is not None
|
||||
self._uses_resume_after = resume_after is not None
|
||||
self._resume_token = copy.deepcopy(start_after or resume_after)
|
||||
self._max_await_time_ms = max_await_time_ms
|
||||
self._batch_size = batch_size
|
||||
self._collation = collation
|
||||
self._start_at_operation_time = start_at_operation_time
|
||||
self._session = session
|
||||
self._comment = comment
|
||||
self._closed = False
|
||||
self._timeout = self._target._timeout
|
||||
self._show_expanded_events = show_expanded_events
|
||||
# Initialize cursor.
|
||||
self._cursor = self._create_cursor()
|
||||
|
||||
@property
|
||||
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
|
||||
"""The aggregation command class to be used."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _client(self) -> MongoClient:
|
||||
"""The client against which the aggregation commands for
|
||||
this ChangeStream will be run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _change_stream_options(self) -> dict[str, Any]:
|
||||
"""Return the options dict for the $changeStream pipeline stage."""
|
||||
options: dict[str, Any] = {}
|
||||
if self._full_document is not None:
|
||||
options["fullDocument"] = self._full_document
|
||||
|
||||
if self._full_document_before_change is not None:
|
||||
options["fullDocumentBeforeChange"] = self._full_document_before_change
|
||||
|
||||
resume_token = self.resume_token
|
||||
if resume_token is not None:
|
||||
if self._uses_start_after:
|
||||
options["startAfter"] = resume_token
|
||||
else:
|
||||
options["resumeAfter"] = resume_token
|
||||
elif self._start_at_operation_time is not None:
|
||||
options["startAtOperationTime"] = self._start_at_operation_time
|
||||
|
||||
if self._show_expanded_events:
|
||||
options["showExpandedEvents"] = self._show_expanded_events
|
||||
|
||||
return options
|
||||
|
||||
def _command_options(self) -> dict[str, Any]:
|
||||
"""Return the options dict for the aggregation command."""
|
||||
options = {}
|
||||
if self._max_await_time_ms is not None:
|
||||
options["maxAwaitTimeMS"] = self._max_await_time_ms
|
||||
if self._batch_size is not None:
|
||||
options["batchSize"] = self._batch_size
|
||||
return options
|
||||
|
||||
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
|
||||
"""Return the full aggregation pipeline for this ChangeStream."""
|
||||
options = self._change_stream_options()
|
||||
full_pipeline: list = [{"$changeStream": options}]
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
|
||||
"""Callback that caches the postBatchResumeToken or
|
||||
startAtOperationTime from a changeStream aggregate command response
|
||||
containing an empty batch of change documents.
|
||||
|
||||
This is implemented as a callback because we need access to the wire
|
||||
version in order to determine whether to cache this value.
|
||||
"""
|
||||
if not result["cursor"]["firstBatch"]:
|
||||
if "postBatchResumeToken" in result["cursor"]:
|
||||
self._resume_token = result["cursor"]["postBatchResumeToken"]
|
||||
elif (
|
||||
self._start_at_operation_time is None
|
||||
and self._uses_resume_after is False
|
||||
and self._uses_start_after is False
|
||||
and conn.max_wire_version >= 7
|
||||
):
|
||||
self._start_at_operation_time = result.get("operationTime")
|
||||
# PYTHON-2181: informative error on missing operationTime.
|
||||
if self._start_at_operation_time is None:
|
||||
raise OperationFailure(
|
||||
"Expected field 'operationTime' missing from command "
|
||||
f"response : {result!r}"
|
||||
)
|
||||
|
||||
def _run_aggregation_cmd(
|
||||
self, session: Optional[ClientSession], explicit_session: bool
|
||||
) -> CommandCursor:
|
||||
"""Run the full aggregation pipeline for this ChangeStream and return
|
||||
the corresponding CommandCursor.
|
||||
"""
|
||||
cmd = self._aggregation_command_class(
|
||||
self._target,
|
||||
CommandCursor,
|
||||
self._aggregation_pipeline(),
|
||||
self._command_options(),
|
||||
explicit_session,
|
||||
result_processor=self._process_result,
|
||||
comment=self._comment,
|
||||
)
|
||||
return self._client._retryable_read(
|
||||
cmd.get_cursor,
|
||||
self._target._read_preference_for(session),
|
||||
session,
|
||||
operation=_Op.AGGREGATE,
|
||||
)
|
||||
|
||||
def _create_cursor(self) -> CommandCursor:
|
||||
with self._client._tmp_session(self._session, close=False) as s:
|
||||
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
|
||||
|
||||
def _resume(self) -> None:
|
||||
"""Reestablish this change stream after a resumable error."""
|
||||
try:
|
||||
self._cursor.close()
|
||||
except PyMongoError:
|
||||
pass
|
||||
self._cursor = self._create_cursor()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close this ChangeStream."""
|
||||
self._closed = True
|
||||
self._cursor.close()
|
||||
|
||||
def __iter__(self) -> ChangeStream[_DocumentType]:
|
||||
return self
|
||||
|
||||
@property
|
||||
def resume_token(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The cached resume token that will be used to resume after the most
|
||||
recently returned change.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
return copy.deepcopy(self._resume_token)
|
||||
|
||||
@_csot.apply
|
||||
def next(self) -> _DocumentType:
|
||||
"""Advance the cursor.
|
||||
|
||||
This method blocks until the next change document is returned or an
|
||||
unrecoverable error is raised. This method is used when iterating over
|
||||
all changes in the cursor. For example::
|
||||
|
||||
try:
|
||||
resume_token = None
|
||||
pipeline = [{'$match': {'operationType': 'insert'}}]
|
||||
with db.collection.watch(pipeline) as stream:
|
||||
for insert_change in stream:
|
||||
print(insert_change)
|
||||
resume_token = stream.resume_token
|
||||
except pymongo.errors.PyMongoError:
|
||||
# The ChangeStream encountered an unrecoverable error or the
|
||||
# resume attempt failed to recreate the cursor.
|
||||
if resume_token is None:
|
||||
# There is no usable resume token because there was a
|
||||
# failure during ChangeStream initialization.
|
||||
logging.error('...')
|
||||
else:
|
||||
# Use the interrupted ChangeStream's resume token to create
|
||||
# a new ChangeStream. The new stream will continue from the
|
||||
# last seen insert change without missing any events.
|
||||
with db.collection.watch(
|
||||
pipeline, resume_after=resume_token) as stream:
|
||||
for insert_change in stream:
|
||||
print(insert_change)
|
||||
|
||||
Raises :exc:`StopIteration` if this ChangeStream is closed.
|
||||
"""
|
||||
while self.alive:
|
||||
doc = self.try_next()
|
||||
if doc is not None:
|
||||
return doc
|
||||
|
||||
raise StopIteration
|
||||
|
||||
__next__ = next
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
|
||||
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
"""
|
||||
return not self._closed
|
||||
|
||||
@_csot.apply
|
||||
def try_next(self) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor without blocking indefinitely.
|
||||
|
||||
This method returns the next change document without waiting
|
||||
indefinitely for the next change. For example::
|
||||
|
||||
with db.collection.watch() as stream:
|
||||
while stream.alive:
|
||||
change = stream.try_next()
|
||||
# Note that the ChangeStream's resume token may be updated
|
||||
# even when no changes are returned.
|
||||
print("Current resume token: %r" % (stream.resume_token,))
|
||||
if change is not None:
|
||||
print("Change document: %r" % (change,))
|
||||
continue
|
||||
# We end up here when there are no recent changes.
|
||||
# Sleep for a while before trying again to avoid flooding
|
||||
# the server with getMore requests when no changes are
|
||||
# available.
|
||||
time.sleep(10)
|
||||
|
||||
If no change document is cached locally then this method runs a single
|
||||
getMore command. If the getMore yields any documents, the next
|
||||
document is returned, otherwise, if the getMore returns no documents
|
||||
(because there have been no changes) then ``None`` is returned.
|
||||
|
||||
:return: The next change document or ``None`` when no document is available
|
||||
after running a single getMore or when the cursor is closed.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
"""
|
||||
if not self._closed and not self._cursor.alive:
|
||||
self._resume()
|
||||
|
||||
# Attempt to get the next change with at most one getMore and at most
|
||||
# one resume attempt.
|
||||
try:
|
||||
try:
|
||||
change = self._cursor._try_next(True)
|
||||
except PyMongoError as exc:
|
||||
if not _resumable(exc):
|
||||
raise
|
||||
self._resume()
|
||||
change = self._cursor._try_next(False)
|
||||
except PyMongoError as exc:
|
||||
# Close the stream after a fatal error.
|
||||
if not _resumable(exc) and not exc.timeout:
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
# Check if the cursor was invalidated.
|
||||
if not self._cursor.alive:
|
||||
self._closed = True
|
||||
|
||||
# If no changes are available.
|
||||
if change is None:
|
||||
# We have either iterated over all documents in the cursor,
|
||||
# OR the most-recently returned batch is empty. In either case,
|
||||
# update the cached resume token with the postBatchResumeToken if
|
||||
# one was returned. We also clear the startAtOperationTime.
|
||||
if self._cursor._post_batch_resume_token is not None:
|
||||
self._resume_token = self._cursor._post_batch_resume_token
|
||||
self._start_at_operation_time = None
|
||||
return change
|
||||
|
||||
# Else, changes are available.
|
||||
try:
|
||||
resume_token = change["_id"]
|
||||
except KeyError:
|
||||
self.close()
|
||||
raise InvalidOperation(
|
||||
"Cannot provide resume functionality when the resume token is missing."
|
||||
) from None
|
||||
|
||||
# If this is the last change document from the current batch, cache the
|
||||
# postBatchResumeToken.
|
||||
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
|
||||
resume_token = self._cursor._post_batch_resume_token
|
||||
|
||||
# Hereafter, don't use startAfter; instead use resumeAfter.
|
||||
self._uses_start_after = False
|
||||
self._uses_resume_after = True
|
||||
|
||||
# Cache the resume token and clear startAtOperationTime.
|
||||
self._resume_token = resume_token
|
||||
self._start_at_operation_time = None
|
||||
|
||||
if self._decode_custom:
|
||||
return _bson_to_dict(change.raw, self._orig_codec_options)
|
||||
return change
|
||||
|
||||
def __enter__(self) -> ChangeStream[_DocumentType]:
    """Enter a ``with`` block; the stream itself is the managed object."""
    return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Leave a ``with`` block by closing this stream.

    The exception details are ignored and ``None`` is returned, so an
    exception raised inside the block keeps propagating.
    """
    self.close()
|
||||
|
||||
|
||||
class CollectionChangeStream(ChangeStream[_DocumentType]):
    """Change stream scoped to a single collection.

    Application developers should not instantiate this class; obtain one
    through :meth:`pymongo.collection.Collection.watch` instead.

    .. versionadded:: 3.7
    """

    _target: Collection[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
        """Aggregation command class associated with this stream type."""
        return _CollectionAggregationCommand

    @property
    def _client(self) -> MongoClient[_DocumentType]:
        """Client reached through the target collection's database."""
        owning_database = self._target.database
        return owning_database.client
|
||||
|
||||
|
||||
class DatabaseChangeStream(ChangeStream[_DocumentType]):
    """Change stream covering every collection of one database.

    Application developers should not instantiate this class; obtain one
    through :meth:`pymongo.database.Database.watch` instead.

    .. versionadded:: 3.7
    """

    _target: Database[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
        """Aggregation command class associated with this stream type."""
        return _DatabaseAggregationCommand

    @property
    def _client(self) -> MongoClient[_DocumentType]:
        """Client that the target database belongs to."""
        target_database = self._target
        return target_database.client
|
||||
|
||||
|
||||
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
    """Change stream covering every collection in the cluster.

    Application developers should not instantiate this class; obtain one
    through :meth:`pymongo.mongo_client.MongoClient.watch` instead.

    .. versionadded:: 3.7
    """

    def _change_stream_options(self) -> dict[str, Any]:
        """Return the parent's options with ``allChangesForCluster`` set to True."""
        stream_options = super()._change_stream_options()
        stream_options["allChangesForCluster"] = True
        return stream_options
|
||||
__doc__ = original_doc
|
||||
|
||||
@ -1,332 +1,21 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools to parse mongo client options."""
|
||||
"""Re-import of synchronous ClientOptions API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
|
||||
from pymongo.synchronous.client_options import * # noqa: F403
|
||||
from pymongo.synchronous.client_options import __doc__ as original_doc
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo import common
|
||||
from pymongo.compression_support import CompressionSettings
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _EventListener, _EventListeners
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import (
|
||||
_ServerMode,
|
||||
make_read_preference,
|
||||
read_pref_mode_from_name,
|
||||
)
|
||||
from pymongo.server_selectors import any_server_selector
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.topology_description import _ServerSelector
|
||||
|
||||
|
||||
def _parse_credentials(
    username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
    """Build credentials from the user name, password and auth options.

    :param username: user name; a falsy value means none was supplied.
    :param password: password forwarded to the credential builder.
    :param database: database name forwarded to the credential builder.
    :param options: client options; ``authmechanism`` and ``authsource``
        are read from it.
    :return: whatever ``_build_credentials_tuple`` produces, or ``None``
        when there is neither a user name nor an auth mechanism.
    """
    # The mechanism only defaults to "DEFAULT" when a user name is present.
    fallback_mechanism = "DEFAULT" if username else None
    mechanism = options.get("authmechanism", fallback_mechanism)
    source = options.get("authsource")
    if not (username or mechanism):
        return None

    # Imported at call time rather than at module import.
    from pymongo.auth import _build_credentials_tuple

    return _build_credentials_tuple(mechanism, source, username, password, options, database)
|
||||
|
||||
|
||||
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
    """Derive the read preference from the client options.

    A ready-made ``read_preference`` entry is returned untouched. Otherwise
    one is built from ``readpreference`` (default ``"primary"``),
    ``readpreferencetags`` and ``maxstalenessseconds`` (default ``-1``).
    """
    if "read_preference" in options:
        return options["read_preference"]

    mode = read_pref_mode_from_name(options.get("readpreference", "primary"))
    return make_read_preference(
        mode,
        options.get("readpreferencetags"),
        options.get("maxstalenessseconds", -1),
    )
|
||||
|
||||
|
||||
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
    """Build a WriteConcern from ``w``, ``wtimeoutms``, ``journal`` and ``fsync``.

    Options that are absent are passed to ``WriteConcern`` as ``None``.
    """
    w, wtimeout, journal, fsync = (
        options.get(key) for key in ("w", "wtimeoutms", "journal", "fsync")
    )
    return WriteConcern(w, wtimeout, journal, fsync)
|
||||
|
||||
|
||||
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
    """Build a ReadConcern from ``readconcernlevel`` (``None`` when absent)."""
    return ReadConcern(options.get("readconcernlevel"))
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
    """Turn the ``tls*`` client options into an SSL context.

    :param options: client options mapping.
    :return: ``(context, allow_invalid_hostnames)``; the context is ``None``
        when TLS ends up disabled.
    :raises ConfigurationError: when ``tls`` is explicitly falsy while
        another option that implies TLS is set.
    """
    use_tls = options.get("tls")
    if use_tls is not None:
        validate_boolean("tls", use_tls)

    certfile = options.get("tlscertificatekeyfile")
    passphrase = options.get("tlscertificatekeyfilepassword")
    ca_certs = options.get("tlscafile")
    crlfile = options.get("tlscrlfile")
    allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False)
    allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)
    disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False)

    # Any non-null value of these options implies tls=True.
    implied_when_set = [
        opt
        for opt in (
            "tlscertificatekeyfile",
            "tlscertificatekeyfilepassword",
            "tlscafile",
            "tlscrlfile",
        )
        if opt in options and options[opt]
    ]
    # A value of False for these options implies tls=True.
    implied_when_false = [
        opt
        for opt in (
            "tlsallowinvalidcertificates",
            "tlsallowinvalidhostnames",
            "tlsdisableocspendpointcheck",
        )
        if opt in options and not options[opt]
    ]
    enabled_tls_opts = implied_when_set + implied_when_false

    if enabled_tls_opts:
        if use_tls is None:
            # Implicitly enable TLS when one of the tls* options is set.
            use_tls = True
        elif not use_tls:
            # tls was explicitly disabled yet a TLS-only option is present.
            raise ConfigurationError(
                "TLS has not been enabled but the "
                "following tls parameters have been set: "
                "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
            )

    if not use_tls:
        return None, allow_invalid_hostnames

    ctx = get_ssl_context(
        certfile,
        passphrase,
        ca_certs,
        crlfile,
        allow_invalid_certificates,
        allow_invalid_hostnames,
        disable_ocsp_endpoint_check,
    )
    return ctx, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(
    username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
    """Assemble the connection pool configuration from the client options.

    :param username: user name handed to ``_parse_credentials``.
    :param password: password handed to ``_parse_credentials``.
    :param database: database name handed to ``_parse_credentials``.
    :param options: client options mapping.
    :return: a ``PoolOptions`` built from the parsed values.
    :raises ValueError: when ``minpoolsize`` exceeds a non-``None``
        ``maxpoolsize``.
    """
    credentials = _parse_credentials(username, password, database, options)
    max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
    min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
    max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
    # A max pool size of None skips the min/max consistency check.
    if max_pool_size is not None and min_pool_size > max_pool_size:
        raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
    connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
    socket_timeout = options.get("sockettimeoutms")
    wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
    listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
    appname = options.get("appname")
    driver = options.get("driver")
    server_api = options.get("server_api")
    compression_settings = CompressionSettings(
        options.get("compressors", []), options.get("zlibcompressionlevel", -1)
    )
    ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
    load_balanced = options.get("loadbalanced")
    max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)

    positional = (
        max_pool_size,
        min_pool_size,
        max_idle_time_seconds,
        connect_timeout,
        socket_timeout,
        wait_queue_timeout,
        ssl_context,
        tls_allow_invalid_hostnames,
        _EventListeners(listeners),
        appname,
        driver,
        compression_settings,
    )
    return PoolOptions(
        *positional,
        max_connecting=max_connecting,
        server_api=server_api,
        load_balanced=load_balanced,
        credentials=credentials,
    )
|
||||
|
||||
|
||||
class ClientOptions:
    """Read-only view of the options a MongoClient was configured with.

    Application developers should not construct this class; read it from
    :attr:`pymongo.mongo_client.MongoClient.options` instead.
    """

    def __init__(
        self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
    ):
        """Parse ``options`` once and cache every derived setting.

        :param username: user name forwarded to the pool option parser.
        :param password: password forwarded to the pool option parser.
        :param database: database name forwarded to the pool option parser.
        :param options: option mapping; kept as-is and exposed via ``_options``.
        """
        lookup = options.get
        self.__options = options
        self.__codec_options = _parse_codec_options(options)
        self.__direct_connection = lookup("directconnection")
        self.__local_threshold_ms = lookup("localthresholdms", common.LOCAL_THRESHOLD_MS)
        # self.__server_selection_timeout is in seconds. The default is read
        # through the full ``common.SERVER_SELECTION_TIMEOUT`` name because
        # tests set that attribute directly.
        self.__server_selection_timeout = lookup(
            "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
        )
        self.__pool_options = _parse_pool_options(username, password, database, options)
        self.__read_preference = _parse_read_preference(options)
        self.__replica_set_name = lookup("replicaset")
        self.__write_concern = _parse_write_concern(options)
        self.__read_concern = _parse_read_concern(options)
        self.__connect = lookup("connect")
        self.__heartbeat_frequency = lookup("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
        self.__retry_writes = lookup("retrywrites", common.RETRY_WRITES)
        self.__retry_reads = lookup("retryreads", common.RETRY_READS)
        self.__server_selector = lookup("server_selector", any_server_selector)
        self.__auto_encryption_opts = lookup("auto_encryption_opts")
        self.__load_balanced = lookup("loadbalanced")
        self.__timeout = lookup("timeoutms")
        self.__server_monitoring_mode = lookup(
            "servermonitoringmode", common.SERVER_MONITORING_MODE
        )

    @property
    def _options(self) -> Mapping[str, Any]:
        """The option mapping this ClientOptions was created from."""
        return self.__options

    @property
    def connect(self) -> Optional[bool]:
        """Whether topology discovery should start automatically."""
        return self.__connect

    @property
    def codec_options(self) -> CodecOptions:
        """The :class:`~bson.codec_options.CodecOptions` in effect."""
        return self.__codec_options

    @property
    def direct_connection(self) -> Optional[bool]:
        """Whether the deployment is connected to in 'Single' topology."""
        return self.__direct_connection

    @property
    def local_threshold_ms(self) -> int:
        """Local threshold configured for this instance."""
        return self.__local_threshold_ms

    @property
    def server_selection_timeout(self) -> int:
        """Server selection timeout for this instance, in seconds."""
        return self.__server_selection_timeout

    @property
    def server_selector(self) -> _ServerSelector:
        """The ``server_selector`` option, or ``any_server_selector`` if unset."""
        return self.__server_selector

    @property
    def heartbeat_frequency(self) -> int:
        """Monitoring frequency, in seconds."""
        return self.__heartbeat_frequency

    @property
    def pool_options(self) -> PoolOptions:
        """The :class:`~pymongo.pool.PoolOptions` in effect."""
        return self.__pool_options

    @property
    def read_preference(self) -> _ServerMode:
        """The read preference instance in effect."""
        return self.__read_preference

    @property
    def replica_set_name(self) -> Optional[str]:
        """Name of the replica set, or None."""
        return self.__replica_set_name

    @property
    def write_concern(self) -> WriteConcern:
        """The :class:`~pymongo.write_concern.WriteConcern` in effect."""
        return self.__write_concern

    @property
    def read_concern(self) -> ReadConcern:
        """The :class:`~pymongo.read_concern.ReadConcern` in effect."""
        return self.__read_concern

    @property
    def timeout(self) -> Optional[float]:
        """The configured timeoutMS converted to seconds, or None.

        .. versionadded:: 4.2
        """
        return self.__timeout

    @property
    def retry_writes(self) -> bool:
        """Whether supported write operations are retried."""
        return self.__retry_writes

    @property
    def retry_reads(self) -> bool:
        """Whether supported read operations are retried."""
        return self.__retry_reads

    @property
    def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
        """The :class:`~pymongo.encryption.AutoEncryptionOpts`, or None."""
        return self.__auto_encryption_opts

    @property
    def load_balanced(self) -> Optional[bool]:
        """True when the client was configured to talk to a load balancer."""
        return self.__load_balanced

    @property
    def event_listeners(self) -> list[_EventListeners]:
        """Event listeners registered for this client.

        See :mod:`~pymongo.monitoring` for details.

        .. versionadded:: 4.0
        """
        listeners = self.__pool_options._event_listeners
        assert listeners is not None
        return listeners.event_listeners()

    @property
    def server_monitoring_mode(self) -> str:
        """The configured serverMonitoringMode option.

        .. versionadded:: 4.5
        """
        return self.__server_monitoring_mode
|
||||
__doc__ = original_doc
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,213 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for working with `collations`_.
|
||||
|
||||
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
|
||||
"""
|
||||
"""Re-import of synchronous Collation API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
from pymongo.synchronous.collation import * # noqa: F403
|
||||
from pymongo.synchronous.collation import __doc__ as original_doc
|
||||
|
||||
from pymongo import common
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
|
||||
class CollationStrength:
    """Allowed values for the `strength` option of a
    :class:`~pymongo.collation.Collation`.
    """

    PRIMARY = 1
    """Differentiate base (unadorned) characters."""

    SECONDARY = 2
    """Differentiate character accents."""

    TERTIARY = 3
    """Differentiate character case."""

    QUATERNARY = 4
    """Differentiate words with and without punctuation."""

    IDENTICAL = 5
    """Differentiate unicode code point (characters are exactly identical)."""
|
||||
|
||||
|
||||
class CollationAlternate:
    """Allowed values for the `alternate` option of a
    :class:`~pymongo.collation.Collation`.
    """

    NON_IGNORABLE = "non-ignorable"
    """Spaces and punctuation are treated as base characters."""

    SHIFTED = "shifted"
    """Spaces and punctuation are *not* considered base characters.

    Spaces and punctuation are distinguished regardless when the
    :class:`~pymongo.collation.Collation` strength is at least
    :data:`~pymongo.collation.CollationStrength.QUATERNARY`.
    """
|
||||
|
||||
|
||||
class CollationMaxVariable:
    """Allowed values for the `max_variable` option of a
    :class:`~pymongo.collation.Collation`.
    """

    PUNCT = "punct"
    """Both punctuation and spaces are ignored."""

    SPACE = "space"
    """Spaces alone are ignored."""
|
||||
|
||||
|
||||
class CollationCaseFirst:
    """Allowed values for the `case_first` option of a
    :class:`~pymongo.collation.Collation`.
    """

    UPPER = "upper"
    """Sort uppercase characters first."""

    LOWER = "lower"
    """Sort lowercase characters first."""

    OFF = "off"
    """Default for locale or collation strength."""
|
||||
|
||||
|
||||
class Collation:
    """Collation

    :param locale: (string) The locale of the collation. This should be a string
        that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
        valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
        documentation for a list of supported locales.
    :param caseLevel: (optional) If ``True``, turn on case sensitivity if
        `strength` is 1 or 2 (case sensitivity is implied if `strength` is
        greater than 2). Defaults to ``False``.
    :param caseFirst: (optional) Specify that either uppercase or lowercase
        characters take precedence. Must be one of the following values:

          * :data:`~CollationCaseFirst.UPPER`
          * :data:`~CollationCaseFirst.LOWER`
          * :data:`~CollationCaseFirst.OFF` (the default)

    :param strength: Specify the comparison strength. This is also
        known as the ICU comparison level. This must be one of the following
        values:

          * :data:`~CollationStrength.PRIMARY`
          * :data:`~CollationStrength.SECONDARY`
          * :data:`~CollationStrength.TERTIARY` (the default)
          * :data:`~CollationStrength.QUATERNARY`
          * :data:`~CollationStrength.IDENTICAL`

        Each successive level builds upon the previous. For example, a
        `strength` of :data:`~CollationStrength.SECONDARY` differentiates
        characters based both on the unadorned base character and its accents.

    :param numericOrdering: If ``True``, order numbers numerically
        instead of in collation order (defaults to ``False``).
    :param alternate: Specify whether spaces and punctuation are
        considered base characters. This must be one of the following values:

          * :data:`~CollationAlternate.NON_IGNORABLE` (the default)
          * :data:`~CollationAlternate.SHIFTED`

    :param maxVariable: When `alternate` is
        :data:`~CollationAlternate.SHIFTED`, this option specifies what
        characters may be ignored. This must be one of the following values:

          * :data:`~CollationMaxVariable.PUNCT` (the default)
          * :data:`~CollationMaxVariable.SPACE`

    :param normalization: If ``True``, normalizes text into Unicode
        NFD. Defaults to ``False``.
    :param backwards: If ``True``, accents on characters are
        considered from the back of the word to the front, as it is done in some
        French dictionary ordering traditions. Defaults to ``False``.
    :param kwargs: Keyword arguments supplying any additional options
        to be sent with this Collation object.

    .. versionadded:: 3.4

    """

    __slots__ = ("__document",)

    def __init__(
        self,
        locale: str,
        caseLevel: Optional[bool] = None,
        caseFirst: Optional[str] = None,
        strength: Optional[int] = None,
        numericOrdering: Optional[bool] = None,
        alternate: Optional[str] = None,
        maxVariable: Optional[str] = None,
        normalization: Optional[bool] = None,
        backwards: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Validate each supplied option and store it in the collation document.

        Options left as ``None`` are omitted from the document. ``kwargs``
        are merged last, without validation, and therefore override any
        same-named validated option.
        """
        locale = common.validate_string("locale", locale)
        self.__document: dict[str, Any] = {"locale": locale}

        # (document key, supplied value, validator). The tuple order fixes
        # the key order of the document, which is visible through repr().
        validated_options = (
            ("caseLevel", caseLevel, validate_boolean),
            ("caseFirst", caseFirst, common.validate_string),
            ("strength", strength, common.validate_integer),
            ("numericOrdering", numericOrdering, validate_boolean),
            ("alternate", alternate, common.validate_string),
            ("maxVariable", maxVariable, common.validate_string),
            ("normalization", normalization, validate_boolean),
            ("backwards", backwards, validate_boolean),
        )
        for key, supplied, validator in validated_options:
            if supplied is not None:
                self.__document[key] = validator(key, supplied)
        self.__document.update(kwargs)

    @property
    def document(self) -> dict[str, Any]:
        """The document representation of this collation.

        .. note::
          :class:`Collation` is immutable. Mutating the value of
          :attr:`document` does not mutate this :class:`Collation`.
        """
        # A shallow copy is handed out so callers cannot alter internal state.
        return self.__document.copy()

    def __repr__(self) -> str:
        """Return ``Collation(key=value, ...)`` in document key order."""
        document = self.document
        return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document))

    def __eq__(self, other: Any) -> bool:
        """Compare documents with another Collation; defer for other types."""
        if isinstance(other, Collation):
            return self.document == other.document
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        """Negation of ``==``."""
        return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(
    value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
    """Normalize a collation argument to a plain document.

    :param value: ``None``, a :class:`Collation`, or a ``dict``.
    :return: ``None`` for ``None``, the collation's document for a
        :class:`Collation`, or the very same ``dict`` object that was passed.
    :raises TypeError: for any other type.
    """
    if value is None:
        return None
    if isinstance(value, Collation):
        return value.document
    if not isinstance(value, dict):
        raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
    return value
|
||||
__doc__ = original_doc
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,390 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CommandCursor class to iterate over command results."""
|
||||
"""Re-import of synchronous CommandCursor API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Iterator,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
from pymongo.synchronous.command_cursor import * # noqa: F403
|
||||
from pymongo.synchronous.command_cursor import __doc__ as original_doc
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
class CommandCursor(Generic[_DocumentType]):
|
||||
"""A cursor / iterator over command cursors."""
|
||||
|
||||
_getmore_class = _GetMore
|
||||
|
||||
def __init__(
    self,
    collection: Collection[_DocumentType],
    cursor_info: Mapping[str, Any],
    address: Optional[_Address],
    batch_size: int = 0,
    max_await_time_ms: Optional[int] = None,
    session: Optional[ClientSession] = None,
    explicit_session: bool = False,
    comment: Any = None,
) -> None:
    """Create a new command cursor.

    :param collection: collection the cursor belongs to.
    :param cursor_info: cursor document; must contain ``id`` and
        ``firstBatch``, may contain ``postBatchResumeToken`` and ``ns``.
    :param address: stored and later passed along with getMore operations.
    :param batch_size: initial batch size, validated by :meth:`batch_size`.
    :param max_await_time_ms: an integer or ``None``.
    :param session: session associated with this cursor, if any.
    :param explicit_session: stored flag; when true the session is not
        ended or dropped by this cursor.
    :param comment: stored and passed along with getMore operations.
    :raises TypeError: if ``batch_size`` or ``max_await_time_ms`` has the
        wrong type.
    :raises ValueError: if ``batch_size`` is negative.
    """
    self.__sock_mgr: Any = None
    self.__collection: Collection[_DocumentType] = collection
    self.__id = cursor_info["id"]
    self.__data = deque(cursor_info["firstBatch"])
    self.__postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
        "postBatchResumeToken"
    )
    self.__address = address
    self.__batch_size = batch_size
    self.__max_await_time_ms = max_await_time_ms
    self.__session = session
    self.__explicit_session = explicit_session
    # A cursor id of 0 means the cursor is treated as already finished.
    self.__killed = self.__id == 0
    self.__comment = comment
    if self.__killed:
        self.__end_session(True)

    # ``collection.full_name`` is only evaluated when "ns" is missing.
    self.__ns = (
        cursor_info["ns"] if "ns" in cursor_info else collection.full_name  # noqa: SIM401
    )

    self.batch_size(batch_size)

    if max_await_time_ms is not None and not isinstance(max_await_time_ms, int):
        raise TypeError("max_await_time_ms must be an integer or None")
|
||||
|
||||
def __del__(self) -> None:
    """Run the cursor teardown (with ``synchronous`` left at False) on finalization."""
    self.__die()
|
||||
|
||||
def __die(self, synchronous: bool = False) -> None:
    """Closes this cursor.

    Marks the cursor as killed and delegates to the client's
    ``_cleanup_cursor``. A real cursor id and address are only passed when
    the cursor has a non-zero id and was not already marked killed.

    :param synchronous: forwarded unchanged to ``_cleanup_cursor``.
    """
    was_killed = self.__killed
    self.__killed = True

    # Defaults mean "skip killCursors".
    cursor_id = 0
    address = None
    if self.__id and not was_killed:
        cursor_id = self.__id
        assert self.__address is not None
        address = _CursorAddress(self.__address, self.__ns)

    client = self.__collection.database.client
    client._cleanup_cursor(
        synchronous,
        cursor_id,
        address,
        self.__sock_mgr,
        self.__session,
        self.__explicit_session,
    )
    # Only a session the caller did not pass explicitly is dropped here.
    if not self.__explicit_session:
        self.__session = None
    self.__sock_mgr = None
|
||||
|
||||
def __end_session(self, synchronous: bool) -> None:
|
||||
if self.__session and not self.__explicit_session:
|
||||
self.__session._end_session(lock=synchronous)
|
||||
self.__session = None
|
||||
|
||||
def close(self) -> None:
    """Explicitly close / kill this cursor (teardown runs with ``synchronous=True``)."""
    self.__die(True)
|
||||
|
||||
def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]:
|
||||
"""Limits the number of documents returned in one batch. Each batch
|
||||
requires a round trip to the server. It can be adjusted to optimize
|
||||
performance and limit data transfer.
|
||||
|
||||
.. note:: batch_size can not override MongoDB's internal limits on the
|
||||
amount of data it will return to the client in a single batch (i.e
|
||||
if you set batch size to 1,000,000,000, MongoDB will currently only
|
||||
return 4-16MB of results per batch).
|
||||
|
||||
Raises :exc:`TypeError` if `batch_size` is not an integer.
|
||||
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
|
||||
|
||||
:param batch_size: The size of each batch of results requested.
|
||||
"""
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
|
||||
self.__batch_size = batch_size == 1 and 2 or batch_size
|
||||
return self
|
||||
|
||||
def _has_next(self) -> bool:
|
||||
"""Returns `True` if the cursor has documents remaining from the
|
||||
previous batch.
|
||||
"""
|
||||
return len(self.__data) > 0
|
||||
|
||||
@property
|
||||
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
|
||||
"""Retrieve the postBatchResumeToken from the response to a
|
||||
changeStream aggregate or getMore.
|
||||
"""
|
||||
return self.__postbatchresumetoken
|
||||
|
||||
def _maybe_pin_connection(self, conn: Connection) -> None:
    """Pin ``conn`` to this cursor when the client asks for pinning.

    Does nothing when the client's ``_should_pin_cursor`` check fails for
    the session or when a connection manager is already stored.
    """
    client = self.__collection.database.client
    if not client._should_pin_cursor(self.__session):
        return
    if self.__sock_mgr:
        return

    conn.pin_cursor()
    manager = _ConnectionManager(conn, False)
    if self.__id == 0:
        # Ensure the connection gets returned when the entire result is
        # returned in the first batch.
        manager.close()
    else:
        self.__sock_mgr = manager
|
||||
|
||||
def __send_message(self, operation: _GetMore) -> None:
    """Send a getmore message and handle the response.

    On success the buffered documents, the cursor id and (for command
    replies) the post-batch resume token are replaced from the response.
    On any exception the cursor is torn down and the exception re-raised.

    :param operation: the getMore operation to run through the client.
    """
    client = self.__collection.database.client
    try:
        reply = client._run_operation(operation, self._unpack_response, address=self.__address)
    except OperationFailure as exc:
        if exc.code in _CURSOR_CLOSED_ERRORS:
            # Don't send killCursors because the cursor is already closed.
            self.__killed = True
        if exc.timeout:
            self.__die(False)
        else:
            # Return the session and pinned connection, if necessary.
            self.close()
        raise
    except ConnectionFailure:
        # Don't send killCursors because the cursor is already closed.
        self.__killed = True
        # Return the session and pinned connection, if necessary.
        self.close()
        raise
    except Exception:
        self.close()
        raise

    # Keep the connection of a pinned response unless one is already held.
    if isinstance(reply, PinnedResponse) and not self.__sock_mgr:
        self.__sock_mgr = _ConnectionManager(reply.conn, reply.more_to_come)

    if reply.from_command:
        cursor_doc = reply.docs[0]["cursor"]
        batch = cursor_doc["nextBatch"]
        self.__postbatchresumetoken = cursor_doc.get("postBatchResumeToken")
        self.__id = cursor_doc["id"]
    else:
        batch = reply.docs
        assert isinstance(reply.data, _OpReply)
        self.__id = reply.data.cursor_id

    # A zero id means there is nothing further to fetch.
    if self.__id == 0:
        self.close()
    self.__data = deque(batch)
|
||||
|
||||
def _unpack_response(
    self,
    response: Union[_OpReply, _OpMsg],
    cursor_id: Optional[int],
    codec_options: CodecOptions[Mapping[str, Any]],
    user_fields: Optional[Mapping[str, Any]] = None,
    legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
    """Decode a server reply by delegating to ``response.unpack_response``.

    All arguments after ``response`` are forwarded positionally, in order.
    """
    decoded = response.unpack_response(
        cursor_id, codec_options, user_fields, legacy_response
    )
    return decoded
|
||||
|
||||
def _refresh(self) -> int:
    """Fetch the next batch from the server when the local buffer is empty.

    Returns the length of self.__data after the refresh. Exits early when
    self.__data already holds documents or the cursor has been killed.
    Raises OperationFailure when the cursor cannot be refreshed due to an
    error on the query.
    """
    if len(self.__data) or self.__killed:
        return len(self.__data)

    if not self.__id:
        # Cursor id is zero: nothing else to return.
        self.__die(True)
        return len(self.__data)

    # Get More
    dbname, collname = self.__ns.split(".", 1)
    read_pref = self.__collection._read_preference_for(self.session)
    getmore = self._getmore_class(
        dbname,
        collname,
        self.__batch_size,
        self.__id,
        self.__collection.codec_options,
        read_pref,
        self.__session,
        self.__collection.database.client,
        self.__max_await_time_ms,
        self.__sock_mgr,
        False,
        self.__comment,
    )
    self.__send_message(getmore)
    return len(self.__data)
|
||||
|
||||
@property
def alive(self) -> bool:
    """Whether this cursor may still return more data.

    ``True`` while documents remain buffered locally or the cursor has not
    been killed.

    Even if :attr:`alive` is ``True``, :meth:`next` can raise
    :exc:`StopIteration`. Best to use a for loop::

        for doc in collection.aggregate(pipeline):
            print(doc)

    .. note:: :attr:`alive` can be True while iterating a cursor from
      a failed server. In this case :attr:`alive` will return False after
      :meth:`next` fails to retrieve the next batch of results from the
      server.
    """
    if len(self.__data):
        return True
    return not self.__killed
|
||||
|
||||
@property
def cursor_id(self) -> int:
    """The id currently stored for this cursor."""
    return self.__id
|
||||
|
||||
@property
def address(self) -> Optional[_Address]:
    """The (host, port) of the server used, or None.

    .. versionadded:: 3.0
    """
    return self.__address
|
||||
|
||||
@property
def session(self) -> Optional[ClientSession]:
    """The cursor's :class:`~pymongo.client_session.ClientSession`, or None.

    The session is only exposed when it was marked as explicit; otherwise
    ``None`` is returned.

    .. versionadded:: 3.6
    """
    return self.__session if self.__explicit_session else None
|
||||
|
||||
def __iter__(self) -> Iterator[_DocumentType]:
    """Return the cursor itself; it is its own iterator."""
    return self
|
||||
|
||||
def next(self) -> _DocumentType:
    """Advance the cursor.

    Raises :exc:`StopIteration` once the cursor is no longer alive.
    """
    # Block until a document is returnable.
    while True:
        if not self.alive:
            raise StopIteration
        candidate = self._try_next(True)
        if candidate is not None:
            return candidate


__next__ = next
|
||||
|
||||
def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
    """Advance the cursor blocking for at most one getMore command.

    :param get_more_allowed: When false, the server is never contacted and
        only a locally buffered document can be returned.
    :return: The next buffered document, or ``None`` when the buffer is empty.
    """
    if not (len(self.__data) or self.__killed) and get_more_allowed:
        self._refresh()
    return self.__data.popleft() if len(self.__data) else None
|
||||
|
||||
def try_next(self) -> Optional[_DocumentType]:
    """Advance the cursor without blocking indefinitely.

    Returns the next document without waiting indefinitely for data.

    When no document is cached locally, a single getMore command is run.
    If that getMore yields any documents, the next document is returned;
    otherwise, if it returns no documents (because there is no additional
    data), ``None`` is returned.

    :return: The next document or ``None`` when no document is available
      after running a single getMore or when the cursor is closed.

    .. versionadded:: 4.5
    """
    outcome = self._try_next(get_more_allowed=True)
    return outcome
|
||||
|
||||
def __enter__(self) -> CommandCursor[_DocumentType]:
    """Enter a ``with`` block, yielding this cursor."""
    return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Close the cursor when leaving a ``with`` block.

    The exception details are ignored, and the implicit ``None`` return
    means an in-flight exception is not suppressed.
    """
    self.close()
|
||||
|
||||
|
||||
class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
    """A command cursor that yields raw batches of BSON data."""

    # getMore operations for this cursor use the raw-batch variant.
    _getmore_class = _RawBatchGetMore

    def __init__(
        self,
        collection: Collection[_DocumentType],
        cursor_info: Mapping[str, Any],
        address: Optional[_Address],
        batch_size: int = 0,
        max_await_time_ms: Optional[int] = None,
        session: Optional[ClientSession] = None,
        explicit_session: bool = False,
        comment: Any = None,
    ) -> None:
        """Create a new cursor / iterator over raw batches of BSON data.

        Should not be called directly by application developers -
        see :meth:`~pymongo.collection.Collection.aggregate_raw_batches`
        instead.

        .. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
        """
        # A raw-batch cursor must be created without a populated first batch.
        assert not cursor_info.get("firstBatch")
        super().__init__(
            collection,
            cursor_info,
            address,
            batch_size,
            max_await_time_ms,
            session,
            explicit_session,
            comment,
        )

    def _unpack_response(  # type: ignore[override]
        self,
        response: Union[_OpReply, _OpMsg],
        cursor_id: Optional[int],
        codec_options: CodecOptions,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[Mapping[str, Any]]:
        """Return the raw reply, without decoding it via ``codec_options``."""
        payload = response.raw_response(cursor_id, user_fields=user_fields)
        if legacy_response:
            return payload  # type: ignore[return-value]
        # OP_MSG returns firstBatch/nextBatch documents as a BSON array.
        # Re-assemble the array of documents into a document stream.
        _convert_raw_document_lists_to_streams(payload[0])
        return payload  # type: ignore[return-value]

    def __getitem__(self, index: int) -> NoReturn:
        """Indexing is not supported; always raises ``InvalidOperation``."""
        raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
|
||||
__doc__ = original_doc
|
||||
|
||||
1347
pymongo/cursor.py
1347
pymongo/cursor.py
File diff suppressed because it is too large
Load Diff
94
pymongo/cursor_shared.py
Normal file
94
pymongo/cursor_shared.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Constants and types shared across all cursor classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Sequence, Tuple, Union
|
||||
|
||||
# These errors mean that the server has already killed the cursor so there is
|
||||
# no need to send killCursors.
|
||||
_CURSOR_CLOSED_ERRORS = frozenset(
    (
        43,  # CursorNotFound
        175,  # QueryPlanKilled
        237,  # CursorKilled
        # The remaining codes are reported on a tailable cursor when the
        # capped collection rolled over.
        # MongoDB 2.6:
        # {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0}
        28617,
        # MongoDB 3.0:
        # {'$err': 'getMore executor error: UnknownError no details available',
        # 'code': 17406, 'ok': 0}
        17406,
        # MongoDB 3.2 + 3.4:
        # {'ok': 0.0, 'errmsg': 'GetMore command executor error:
        # CappedPositionLost: CollectionScan died due to failure to restore
        # tailable cursor position. Last seen record id: RecordId(3)',
        # 'code': 96}
        96,
        # MongoDB 3.6+:
        # {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to
        # restore tailable cursor position. Last seen record id: RecordId(3)"',
        # 'code': 136, 'codeName': 'CappedPositionLost'}
        136,
    )
)
|
||||
|
||||
# Query option names mapped to their flag values; each value is a single bit.
_QUERY_OPTIONS = {
    "tailable_cursor": 1 << 1,
    "secondary_okay": 1 << 2,
    "oplog_replay": 1 << 3,
    "no_timeout": 1 << 4,
    "await_data": 1 << 5,
    "exhaust": 1 << 6,
    "partial": 1 << 7,
}
|
||||
|
||||
|
||||
class CursorType:
    NON_TAILABLE = 0
    """The standard cursor type."""

    TAILABLE = _QUERY_OPTIONS["tailable_cursor"]
    """The tailable cursor type.

    Tailable cursors are only for use with capped collections. Such a cursor
    is not closed once the last data has been retrieved; it stays open, with
    its location marking the final document position. When more data is
    received, iteration continues from the last document received.
    """

    TAILABLE_AWAIT = TAILABLE | _QUERY_OPTIONS["await_data"]
    """A tailable cursor with the await option set.

    Creates a tailable cursor that waits for a few seconds after returning
    the full result set, so that additional data added during the query can
    be captured and returned.
    """

    EXHAUST = _QUERY_OPTIONS["exhaust"]
    """An exhaust cursor.

    MongoDB streams batched results to the client without waiting for the
    client to request each batch, reducing latency.
    """
|
||||
|
||||
|
||||
# A sort specification: either a sequence whose items are a key name or a
# (key, direction) pair, or a mapping.
_Sort = Union[
    Sequence[
        Union[
            str,
            Tuple[str, Union[int, str, Mapping[str, Any]]],
        ]
    ],
    Mapping[str, Any],
]
# A hint: either a string or anything accepted as a sort specification.
_Hint = Union[str, _Sort]
|
||||
1377
pymongo/database.py
1377
pymongo/database.py
File diff suppressed because it is too large
Load Diff
34
pymongo/database_shared.py
Normal file
34
pymongo/database_shared.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Constants, helpers, and types shared across all database classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, TypeVar
|
||||
|
||||
from pymongo.errors import InvalidName
|
||||
|
||||
|
||||
def _check_name(name: str) -> None:
    """Validate ``name`` as a database name.

    :param name: The candidate database name.
    :raises InvalidName: If ``name`` is empty or contains one of the
        rejected characters; the first rejected character found, in the
        order listed below, is reported.
    """
    if not name:
        raise InvalidName("database name cannot be the empty string")

    forbidden = (" ", ".", "$", "/", "\\", "\x00", '"')
    offender = next((char for char in forbidden if char in name), None)
    if offender is not None:
        raise InvalidName("database names cannot contain the character %r" % offender)
|
||||
|
||||
|
||||
# Type variable for document types; bound to ``Mapping[str, Any]``.
_CodecDocumentType = TypeVar("_CodecDocumentType", bound=Mapping[str, Any])
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,257 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
"""Re-import of synchronous EncryptionOptions API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
from pymongo.synchronous.encryption_options import * # noqa: F403
|
||||
from pymongo.synchronous.encryption_options import __doc__ as original_doc
|
||||
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import] # noqa: F401
|
||||
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
except ImportError:
|
||||
_HAVE_PYMONGOCRYPT = False
|
||||
from bson import int64
|
||||
from pymongo.common import validate_is_mapping
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.typings import _DocumentTypeArg
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
    """Options to configure automatic client-side field level encryption."""

    def __init__(
        self,
        kms_providers: Mapping[str, Any],
        key_vault_namespace: str,
        key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None,
        schema_map: Optional[Mapping[str, Any]] = None,
        bypass_auto_encryption: bool = False,
        mongocryptd_uri: str = "mongodb://localhost:27020",
        mongocryptd_bypass_spawn: bool = False,
        mongocryptd_spawn_path: str = "mongocryptd",
        mongocryptd_spawn_args: Optional[list[str]] = None,
        kms_tls_options: Optional[Mapping[str, Any]] = None,
        crypt_shared_lib_path: Optional[str] = None,
        crypt_shared_lib_required: bool = False,
        bypass_query_analysis: bool = False,
        encrypted_fields_map: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Options to configure automatic client-side field level encryption.

        Automatic client-side field level encryption requires MongoDB >=4.2
        enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
        supported for operations on a database or view and will result in
        error.

        Although automatic encryption requires MongoDB >=4.2 enterprise or a
        MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
        users. To configure automatic *decryption* without automatic
        *encryption* set ``bypass_auto_encryption=True``. Explicit
        encryption and explicit decryption is also supported for all users
        with the :class:`~pymongo.encryption.ClientEncryption` class.

        See :ref:`automatic-client-side-encryption` for an example.

        :param kms_providers: Map of KMS provider options. The `kms_providers`
            map values differ by provider:

              - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
                These are the AWS access key ID and AWS secret access key used
                to generate KMS messages. An optional "sessionToken" may be
                included to support temporary AWS credentials.
              - `azure`: Map with "tenantId", "clientId", and "clientSecret" as
                strings. Additionally, "identityPlatformEndpoint" may also be
                specified as a string (defaults to 'login.microsoftonline.com').
                These are the Azure Active Directory credentials used to
                generate Azure Key Vault messages.
              - `gcp`: Map with "email" as a string and "privateKey"
                as `bytes` or a base64 encoded string.
                Additionally, "endpoint" may also be specified as a string
                (defaults to 'oauth2.googleapis.com'). These are the
                credentials used to generate Google Cloud KMS messages.
              - `kmip`: Map with "endpoint" as a host with required port.
                For example: ``{"endpoint": "example.com:443"}``.
              - `local`: Map with "key" as `bytes` (96 bytes in length) or
                a base64 encoded string which decodes
                to 96 bytes. "key" is the master key used to encrypt/decrypt
                data keys. This key should be generated and stored as securely
                as possible.

            KMS providers may be specified with an optional name suffix
            separated by a colon, for example "kmip:name" or "aws:name".
            Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
            Named KMS providers enable more than one of each KMS provider type to be configured.
            For example, to configure multiple local KMS providers::

              kms_providers = {
                  "local": {"key": local_kek1},  # Unnamed KMS provider.
                  "local:myname": {"key": local_kek2},  # Named KMS provider with name "myname".
              }

        :param key_vault_namespace: The namespace for the key vault collection.
            The key vault collection contains all data keys used for encryption
            and decryption. Data keys are stored as documents in this MongoDB
            collection. Data keys are protected with encryption by a KMS
            provider.
        :param key_vault_client: By default, the key vault collection
            is assumed to reside in the same MongoDB cluster as the encrypted
            MongoClient. Use this option to route data key queries to a
            separate MongoDB cluster.
        :param schema_map: Map of collection namespace ("db.coll") to
            JSON Schema. By default, a collection's JSONSchema is periodically
            polled with the listCollections command. But a JSONSchema may be
            specified locally with the schemaMap option.

            **Supplying a `schema_map` provides more security than relying on
            JSON Schemas obtained from the server. It protects against a
            malicious server advertising a false JSON Schema, which could trick
            the client into sending unencrypted data that should be
            encrypted.**

            Schemas supplied in the schemaMap only apply to configuring
            automatic encryption for client side encryption. Other validation
            rules in the JSON schema will not be enforced by the driver and
            will result in an error.
        :param bypass_auto_encryption: If ``True``, automatic
            encryption will be disabled but automatic decryption will still be
            enabled. Defaults to ``False``.
        :param mongocryptd_uri: The MongoDB URI used to connect
            to the *local* mongocryptd process. Defaults to
            ``'mongodb://localhost:27020'``.
        :param mongocryptd_bypass_spawn: If ``True``, the encrypted
            MongoClient will not attempt to spawn the mongocryptd process.
            Defaults to ``False``.
        :param mongocryptd_spawn_path: Used for spawning the
            mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
            mongocryptd from the system path.
        :param mongocryptd_spawn_args: A list of string arguments to
            use when spawning the mongocryptd process. Defaults to
            ``['--idleShutdownTimeoutSecs=60']``. If the list does not include
            the ``idleShutdownTimeoutSecs`` option then
            ``'--idleShutdownTimeoutSecs=60'`` will be added. The list passed
            in is copied and is never modified.
        :param kms_tls_options: A map of KMS provider names to TLS
            options to use when creating secure connections to KMS providers.
            Accepts the same TLS options as
            :class:`pymongo.mongo_client.MongoClient`. For example, to
            override the system default CA file::

              kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}

            Or to supply a client certificate::

              kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
        :param crypt_shared_lib_path: Override the path to load the crypt_shared library.
        :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
            unable to load the crypt_shared library.
        :param bypass_query_analysis: If ``True``, disable automatic analysis
            of outgoing commands. Set `bypass_query_analysis` to use explicit
            encryption on indexed fields without the MongoDB Enterprise Advanced
            licensed crypt_shared library.
        :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
            that describe the encrypted fields for Queryable Encryption. For example::

                {
                  "db.encryptedCollection": {
                      "escCollection": "enxcol_.encryptedCollection.esc",
                      "ecocCollection": "enxcol_.encryptedCollection.ecoc",
                      "fields": [
                          {
                              "path": "firstName",
                              "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
                              "bsonType": "string",
                              "queries": {"queryType": "equality"}
                          },
                          {
                              "path": "ssn",
                              "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
                              "bsonType": "string"
                          }
                      ]
                  }
                }

        :raises ConfigurationError: If the pymongocrypt library is not available.
        :raises TypeError: If ``mongocryptd_spawn_args`` is given but is not a list.

        .. versionchanged:: 4.2
           Added `encrypted_fields_map`, `crypt_shared_lib_path`, `crypt_shared_lib_required`,
           and `bypass_query_analysis` parameters.

        .. versionchanged:: 4.0
           Added the `kms_tls_options` parameter and the "kmip" KMS provider.

        .. versionadded:: 3.9
        """
        if not _HAVE_PYMONGOCRYPT:
            raise ConfigurationError(
                "client side encryption requires the pymongocrypt library: "
                "install a compatible version with: "
                "python -m pip install 'pymongo[encryption]'"
            )
        if encrypted_fields_map:
            validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
        self._encrypted_fields_map = encrypted_fields_map
        self._bypass_query_analysis = bypass_query_analysis
        self._crypt_shared_lib_path = crypt_shared_lib_path
        self._crypt_shared_lib_required = crypt_shared_lib_required
        self._kms_providers = kms_providers
        self._key_vault_namespace = key_vault_namespace
        self._key_vault_client = key_vault_client
        self._schema_map = schema_map
        self._bypass_auto_encryption = bypass_auto_encryption
        self._mongocryptd_uri = mongocryptd_uri
        self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
        self._mongocryptd_spawn_path = mongocryptd_spawn_path

        if mongocryptd_spawn_args is None:
            spawn_args = ["--idleShutdownTimeoutSecs=60"]
        elif isinstance(mongocryptd_spawn_args, list):
            # Work on a copy so the default option appended below never
            # leaks into the list object owned by the caller.
            spawn_args = list(mongocryptd_spawn_args)
        else:
            raise TypeError("mongocryptd_spawn_args must be a list")
        if not any("idleShutdownTimeoutSecs" in arg for arg in spawn_args):
            spawn_args.append("--idleShutdownTimeoutSecs=60")
        self._mongocryptd_spawn_args = spawn_args

        # Maps KMS provider name to a SSLContext.
        self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
|
||||
|
||||
|
||||
class RangeOpts:
    """Options to configure encrypted queries using the rangePreview algorithm."""

    def __init__(
        self,
        sparsity: int,
        min: Optional[Any] = None,
        max: Optional[Any] = None,
        precision: Optional[int] = None,
    ) -> None:
        """Options to configure encrypted queries using the rangePreview algorithm.

        .. note:: This feature is experimental only, and not intended for public use.

        :param sparsity: An integer.
        :param min: A BSON scalar value corresponding to the type being queried.
        :param max: A BSON scalar value corresponding to the type being queried.
        :param precision: An integer, may only be set for double or decimal128 types.

        .. versionadded:: 4.4
        """
        self.sparsity = sparsity
        self.precision = precision
        self.min = min
        self.max = max

    @property
    def document(self) -> dict[str, Any]:
        """The options as a dict, leaving out every value that is ``None``.

        ``sparsity`` is wrapped in ``int64.Int64``; keys appear in the order
        sparsity, precision, min, max.
        """
        candidates = {
            "sparsity": int64.Int64(self.sparsity),
            "precision": self.precision,
            "min": self.min,
            "max": self.max,
        }
        return {key: value for key, value in candidates.items() if value is not None}
|
||||
__doc__ = original_doc
|
||||
|
||||
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Un
|
||||
from bson.errors import InvalidDocument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _DocumentOut
|
||||
from pymongo.asynchronous.typings import _DocumentOut
|
||||
|
||||
|
||||
class PyMongoError(Exception):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,212 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Example event logger classes.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
These loggers can be registered using :func:`register` or
|
||||
:class:`~pymongo.mongo_client.MongoClient`.
|
||||
|
||||
``monitoring.register(CommandLogger())``
|
||||
|
||||
or
|
||||
|
||||
``MongoClient(event_listeners=[CommandLogger()])``
|
||||
"""
|
||||
"""Re-import of synchronous EventLoggers API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pymongo.synchronous.event_loggers import * # noqa: F403
|
||||
from pymongo.synchronous.event_loggers import __doc__ as original_doc
|
||||
|
||||
from pymongo import monitoring
|
||||
|
||||
|
||||
class CommandLogger(monitoring.CommandListener):
    """A simple listener that logs command events.

    Handles :class:`~pymongo.monitoring.CommandStartedEvent`,
    :class:`~pymongo.monitoring.CommandSucceededEvent` and
    :class:`~pymongo.monitoring.CommandFailedEvent` events, writing each
    one at the `INFO` severity level using :mod:`logging`.

    .. versionadded:: 3.11
    """

    def started(self, event: monitoring.CommandStartedEvent) -> None:
        """Log that a command was started."""
        message = (
            f"Command {event.command_name} with request id "
            f"{event.request_id} started on server "
            f"{event.connection_id}"
        )
        logging.info(message)

    def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
        """Log that a command succeeded, with its duration."""
        message = (
            f"Command {event.command_name} with request id "
            f"{event.request_id} on server {event.connection_id} "
            f"succeeded in {event.duration_micros} "
            "microseconds"
        )
        logging.info(message)

    def failed(self, event: monitoring.CommandFailedEvent) -> None:
        """Log that a command failed, with its duration."""
        message = (
            f"Command {event.command_name} with request id "
            f"{event.request_id} on server {event.connection_id} "
            f"failed in {event.duration_micros} "
            "microseconds"
        )
        logging.info(message)
|
||||
|
||||
|
||||
class ServerLogger(monitoring.ServerListener):
    """A simple listener that logs server discovery events.

    Handles :class:`~pymongo.monitoring.ServerOpeningEvent`,
    :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
    and :class:`~pymongo.monitoring.ServerClosedEvent` events using
    :mod:`logging`. Opening and type changes are logged at the `INFO`
    severity level; a closed server is logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def opened(self, event: monitoring.ServerOpeningEvent) -> None:
        """Log that a server was added to the topology."""
        message = f"Server {event.server_address} added to topology {event.topology_id}"
        logging.info(message)

    def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
        """Log a server type change; other description changes are ignored."""
        before = event.previous_description.server_type
        after = event.new_description.server_type
        if after == before:
            return
        # server_type_name was added in PyMongo 3.4
        logging.info(
            f"Server {event.server_address} changed type from "
            f"{event.previous_description.server_type_name} to "
            f"{event.new_description.server_type_name}"
        )

    def closed(self, event: monitoring.ServerClosedEvent) -> None:
        """Log, as a warning, that a server was removed from the topology."""
        message = f"Server {event.server_address} removed from topology {event.topology_id}"
        logging.warning(message)
|
||||
|
||||
|
||||
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
    """A simple listener that logs server heartbeat events.

    Handles :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
    :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
    and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` events
    using :mod:`logging`. Started and succeeded heartbeats are logged at the
    `INFO` severity level; failed heartbeats are logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
        """Log that a heartbeat was sent."""
        message = f"Heartbeat sent to server {event.connection_id}"
        logging.info(message)

    def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
        """Log a successful heartbeat together with the reply document."""
        # The reply.document attribute was added in PyMongo 3.4.
        message = (
            f"Heartbeat to server {event.connection_id} "
            "succeeded with reply "
            f"{event.reply.document}"
        )
        logging.info(message)

    def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
        """Log, as a warning, a failed heartbeat together with its error."""
        message = f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
        logging.warning(message)
|
||||
|
||||
|
||||
class TopologyLogger(monitoring.TopologyListener):
    """A simple listener that logs server topology events.

    Handles :class:`~pymongo.monitoring.TopologyOpenedEvent`,
    :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
    and :class:`~pymongo.monitoring.TopologyClosedEvent` events using
    :mod:`logging`. Events are logged at the `INFO` severity level; a
    topology without writable or readable servers is logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
        """Log that a topology was opened."""
        message = f"Topology with id {event.topology_id} opened"
        logging.info(message)

    def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
        """Log a topology description update and any resulting warnings."""
        logging.info(f"Topology description updated for topology id {event.topology_id}")

        before = event.previous_description.topology_type
        after = event.new_description.topology_type
        type_changed = after != before
        if type_changed:
            # topology_type_name was added in PyMongo 3.4
            logging.info(
                f"Topology {event.topology_id} changed type from "
                f"{event.previous_description.topology_type_name} to "
                f"{event.new_description.topology_type_name}"
            )

        # The has_writable_server and has_readable_server methods
        # were added in PyMongo 3.4.
        writable = event.new_description.has_writable_server()
        if not writable:
            logging.warning("No writable servers available.")
        readable = event.new_description.has_readable_server()
        if not readable:
            logging.warning("No readable servers available.")

    def closed(self, event: monitoring.TopologyClosedEvent) -> None:
        """Log that a topology was closed."""
        message = f"Topology with id {event.topology_id} closed"
        logging.info(message)
|
||||
|
||||
|
||||
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
    """A simple listener that logs server connection pool events.

    Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
    :class:`~pymongo.monitoring.PoolReadyEvent`,
    :class:`~pymongo.monitoring.PoolClearedEvent`,
    :class:`~pymongo.monitoring.PoolClosedEvent`,
    :class:`~pymongo.monitoring.ConnectionCreatedEvent`,
    :class:`~pymongo.monitoring.ConnectionReadyEvent`,
    :class:`~pymongo.monitoring.ConnectionClosedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
    and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
    events and logs them at the `INFO` severity level using :mod:`logging`.

    .. versionadded:: 3.11
    """

    def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
        """Log that the pool for ``event.address`` was created."""
        logging.info(f"[pool {event.address}] pool created")

    def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
        """Log that the pool for ``event.address`` is ready."""
        logging.info(f"[pool {event.address}] pool ready")

    def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
        """Log that the pool for ``event.address`` was cleared."""
        logging.info(f"[pool {event.address}] pool cleared")

    def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
        """Log that the pool for ``event.address`` was closed."""
        logging.info(f"[pool {event.address}] pool closed")

    def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
        """Log that a connection was created in the pool."""
        logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")

    def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
        """Log that a connection finished its setup."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
        )

    def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
        """Log that a connection was closed, including ``event.reason``."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] "
            f'connection closed, reason: "{event.reason}"'
        )

    def connection_check_out_started(
        self, event: monitoring.ConnectionCheckOutStartedEvent
    ) -> None:
        """Log that a connection check out was started on the pool."""
        logging.info(f"[pool {event.address}] connection check out started")

    def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
        """Log that a connection check out failed, including ``event.reason``."""
        logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")

    def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
        """Log that a connection was checked out of the pool."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
        )

    def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
        """Log that a connection was checked back into the pool."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
        )
|
||||
__doc__ = original_doc
|
||||
|
||||
72
pymongo/helpers_constants.py
Normal file
72
pymongo/helpers_constants.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Constants used by the driver that don't really fit elsewhere."""
|
||||
|
||||
# From the SDAM spec, the "node is shutting down" codes.
|
||||
from __future__ import annotations
|
||||
|
||||
# "Node is shutting down" error codes, from the SDAM spec.
_SHUTDOWN_CODES: frozenset = frozenset(
    (
        11600,  # InterruptedAtShutdown
        91,  # ShutdownInProgress
    )
)
# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = _SHUTDOWN_CODES | frozenset(
    (
        10058,  # LegacyNotPrimary <=3.2 "not primary" error code
        10107,  # NotWritablePrimary
        13435,  # NotPrimaryNoSecondaryOk
        11602,  # InterruptedDueToReplStateChange
        13436,  # NotPrimaryOrSecondary
        189,  # PrimarySteppedDown
    )
)
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES.union(
    (
        7,  # HostNotFound
        6,  # HostUnreachable
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
        134,  # ReadConcernMajorityNotAvailableYet
    )
)

# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391

# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18

# Note - to avoid bugs from forgetting which of these are all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = set(
    (
        "authenticate",
        "saslstart",
        "saslcontinue",
        "getnonce",
        "createuser",
        "updateuser",
        "copydbgetnonce",
        "copydbsaslstart",
        "copydb",
    )
)
|
||||
102
pymongo/lock.py
102
pymongo/lock.py
@ -13,9 +13,12 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||
|
||||
@ -38,3 +41,102 @@ def _release_locks() -> None:
|
||||
for lock in _forkable_locks:
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
|
||||
class _ALock:
    """Wrap a ``threading.Lock`` so it can be used from sync and async code.

    Supports both ``with`` and ``async with``. The async acquire polls the
    underlying lock without blocking and yields to the event loop between
    attempts.
    """

    def __init__(self, lock: threading.Lock) -> None:
        self._lock = lock

    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        """Acquire the wrapped lock synchronously; same contract as ``Lock.acquire``."""
        return self._lock.acquire(blocking=blocking, timeout=timeout)

    async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        """Acquire the wrapped lock from a coroutine.

        :param blocking: when false, give up after a single failed attempt.
        :param timeout: when positive, give up once more than this many
            seconds have elapsed; otherwise retry without a deadline.
        :return: ``True`` if the lock was acquired, ``False`` otherwise.
        """
        has_deadline = timeout > 0
        started = time.monotonic() if has_deadline else 0.0
        while not self._lock.acquire(blocking=False):
            if has_deadline and (time.monotonic() - started) > timeout:
                return False
            if not blocking:
                return False
            # Let other tasks run before the next attempt.
            await asyncio.sleep(0)
        return True

    def release(self) -> None:
        """Release the wrapped lock."""
        self._lock.release()

    async def __aenter__(self) -> _ALock:
        await self.a_acquire()
        return self

    def __enter__(self) -> _ALock:
        self._lock.acquire()
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.release()

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.release()
|
||||
|
||||
|
||||
class _ACondition:
    """Wrap a ``threading.Condition`` behind coroutine-friendly methods.

    Supports both ``with`` and ``async with``.

    NOTE(review): ``wait`` and ``wait_for`` poll the wrapped condition in
    1 ms blocking slices and contain no ``await``, so the calling thread is
    blocked while they run. Confirm this is acceptable for callers that
    share an event loop with the notifier.
    """

    def __init__(self, condition: threading.Condition) -> None:
        self._condition = condition

    async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        """Acquire the condition's lock from a coroutine.

        :param blocking: when false, give up after a single failed attempt.
        :param timeout: when positive, give up once more than this many
            seconds have elapsed; otherwise retry without a deadline.
        :return: ``True`` if the lock was acquired, ``False`` otherwise.
        """
        has_deadline = timeout > 0
        started = time.monotonic() if has_deadline else 0.0
        while not self._condition.acquire(blocking=False):
            if has_deadline and (time.monotonic() - started) > timeout:
                return False
            if not blocking:
                return False
            # Let other tasks run before the next attempt.
            await asyncio.sleep(0)
        return True

    async def wait(self, timeout: Optional[float] = None) -> bool:
        """Wait for a notification, polling in 1 ms slices.

        :param timeout: seconds after which to give up; ``None`` waits forever.
        :return: ``True`` if notified, ``False`` if the timeout elapsed.
        """
        bounded = timeout is not None
        started = time.monotonic() if bounded else 0.0
        while not self._condition.wait(0.001):
            if bounded and (time.monotonic() - started) > timeout:
                return False
        return True

    async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
        """Wait until ``predicate`` is true, polling in 1 ms slices.

        :param predicate: callable evaluated by ``Condition.wait_for``.
        :param timeout: seconds after which to give up; ``None`` waits forever.
        :return: ``True`` once the predicate held, ``False`` on timeout.
        """
        bounded = timeout is not None
        started = time.monotonic() if bounded else 0.0
        while not self._condition.wait_for(predicate, 0.001):
            if bounded and (time.monotonic() - started) > timeout:
                return False
        return True

    def notify(self, n: int = 1) -> None:
        """Wake up to ``n`` waiters."""
        self._condition.notify(n)

    def notify_all(self) -> None:
        """Wake every waiter."""
        self._condition.notify_all()

    def release(self) -> None:
        """Release the condition's lock."""
        self._condition.release()

    async def __aenter__(self) -> _ACondition:
        await self.acquire()
        return self

    def __enter__(self) -> _ACondition:
        self._condition.acquire()
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.release()

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.release()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
49
pymongo/network_layer.py
Normal file
49
pymongo/network_layer.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Internal network layer helper methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import struct
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import ssl_support
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
_POLL_TIMEOUT = 0.5
|
||||
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
async def async_sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
    """Send all of ``buf`` over ``socket`` from a coroutine.

    The socket is switched to non-blocking mode for the duration of the
    send, and the timeout it had on entry bounds the whole operation. That
    timeout is restored before returning, whether or not the send succeeds.

    :param socket: the connected socket to write to.
    :param buf: the bytes to send.
    :raises asyncio.TimeoutError: if the send does not finish within the
        socket's original timeout.
    """
    timeout = socket.gettimeout()
    socket.settimeout(0.0)
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # documented accessor for that case and matches the rest of the driver.
    loop = asyncio.get_running_loop()
    try:
        await asyncio.wait_for(loop.sock_sendall(socket, buf), timeout=timeout)  # type: ignore[arg-type]
    finally:
        socket.settimeout(timeout)
|
||||
|
||||
|
||||
def sendall(socket: Union[socket.socket, _sslConn], buf: bytes) -> None:
    """Send all of ``buf`` over ``socket`` using the socket's own ``sendall``."""
    socket.sendall(buf)
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,612 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Operation class definitions."""
|
||||
"""Re-import of synchronous Operations API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from pymongo.synchronous.operations import * # noqa: F403
|
||||
from pymongo.synchronous.operations import __doc__ as original_doc
|
||||
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import helpers
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import validate_is_mapping, validate_list
|
||||
from pymongo.helpers import _gen_index_name, _index_document, _index_list
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.bulk import _Bulk
|
||||
|
||||
# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z'], or a dictionary
|
||||
_IndexList = Union[
|
||||
Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
|
||||
]
|
||||
_IndexKeyHint = Union[str, _IndexList]
|
||||
|
||||
|
||||
class _Op(str, enum.Enum):
    """Operation names; each member is also a ``str`` equal to its value."""

    ABORT = "abortTransaction"
    AGGREGATE = "aggregate"
    COMMIT = "commitTransaction"
    COUNT = "count"
    CREATE = "create"
    CREATE_INDEXES = "createIndexes"
    CREATE_SEARCH_INDEXES = "createSearchIndexes"
    DELETE = "delete"
    DISTINCT = "distinct"
    DROP = "drop"
    DROP_DATABASE = "dropDatabase"
    DROP_INDEXES = "dropIndexes"
    DROP_SEARCH_INDEXES = "dropSearchIndexes"
    END_SESSIONS = "endSessions"
    FIND_AND_MODIFY = "findAndModify"
    FIND = "find"
    INSERT = "insert"
    LIST_COLLECTIONS = "listCollections"
    LIST_INDEXES = "listIndexes"
    LIST_SEARCH_INDEX = "listSearchIndexes"
    LIST_DATABASES = "listDatabases"
    UPDATE = "update"
    UPDATE_INDEX = "updateIndex"
    UPDATE_SEARCH_INDEX = "updateSearchIndex"
    RENAME = "rename"
    GETMORE = "getMore"
    KILL_CURSORS = "killCursors"
    TEST = "testOperation"
|
||||
|
||||
|
||||
class InsertOne(Generic[_DocumentType]):
    """Represents an insert_one operation."""

    __slots__ = ("_doc",)

    def __init__(self, document: _DocumentType) -> None:
        """Create an InsertOne instance.

        For use with :meth:`~pymongo.collection.Collection.bulk_write`.

        :param document: The document to insert. If the document is missing an
            _id field one will be added.
        """
        self._doc = document

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        bulkobj.add_insert(self._doc)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        return f"InsertOne({self._doc!r})"

    def __eq__(self, other: Any) -> bool:
        # Only instances of exactly the same class are comparable.
        if type(other) != type(self):
            return NotImplemented
        return other._doc == self._doc

    def __ne__(self, other: Any) -> bool:
        return not (self == other)
|
||||
|
||||
|
||||
class DeleteOne:
    """Represents a delete_one operation."""

    __slots__ = ("_filter", "_collation", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        collation: Optional[_CollationIn] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create a DeleteOne instance.

        For use with :meth:`~pymongo.collection.Collection.bulk_write`.

        :param filter: A query that matches the document to delete.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.Collection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.4 and above.

        .. versionchanged:: 3.11
           Added the ``hint`` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        # String hints (index names) and None are stored untouched; any other
        # hint is converted to an index document.
        if hint is None or isinstance(hint, str):
            self._hint: Union[str, dict[str, Any], None] = hint
        else:
            self._hint = helpers._index_document(hint)
        self._filter = filter
        self._collation = collation

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        # The literal 1 is the limit argument used for delete_one.
        bulkobj.add_delete(self._filter, 1, collation=checked_collation, hint=self._hint)

    def __repr__(self) -> str:
        return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})"

    def __eq__(self, other: Any) -> bool:
        # Only instances of exactly the same class are comparable.
        if type(other) != type(self):
            return NotImplemented
        theirs = (other._filter, other._collation, other._hint)
        mine = (self._filter, self._collation, self._hint)
        return theirs == mine

    def __ne__(self, other: Any) -> bool:
        return not (self == other)
|
||||
|
||||
|
||||
class DeleteMany:
    """Represents a delete_many operation."""

    __slots__ = ("_filter", "_collation", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        collation: Optional[_CollationIn] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create a DeleteMany instance.

        For use with :meth:`~pymongo.collection.Collection.bulk_write`.

        :param filter: A query that matches the documents to delete.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.Collection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.4 and above.

        .. versionchanged:: 3.11
           Added the ``hint`` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        # String hints (index names) and None are stored untouched; any other
        # hint is converted to an index document.
        if hint is None or isinstance(hint, str):
            self._hint: Union[str, dict[str, Any], None] = hint
        else:
            self._hint = helpers._index_document(hint)
        self._filter = filter
        self._collation = collation

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        # The literal 0 is the limit argument used for delete_many.
        bulkobj.add_delete(self._filter, 0, collation=checked_collation, hint=self._hint)

    def __repr__(self) -> str:
        return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})"

    def __eq__(self, other: Any) -> bool:
        # Only instances of exactly the same class are comparable.
        if type(other) != type(self):
            return NotImplemented
        theirs = (other._filter, other._collation, other._hint)
        mine = (self._filter, self._collation, self._hint)
        return theirs == mine

    def __ne__(self, other: Any) -> bool:
        return not (self == other)
|
||||
|
||||
|
||||
class ReplaceOne(Generic[_DocumentType]):
    """Represents a replace_one operation."""

    __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        replacement: Union[_DocumentType, RawBSONDocument],
        upsert: bool = False,
        collation: Optional[_CollationIn] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create a ReplaceOne instance.

        For use with :meth:`~pymongo.collection.Collection.bulk_write`.

        :param filter: A query that matches the document to replace.
        :param replacement: The new document.
        :param upsert: If ``True``, perform an insert if no documents
            match the filter.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.Collection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.2 and above.

        .. versionchanged:: 3.11
           Added the ``hint`` option.
        .. versionchanged:: 3.5
           Added the ``collation`` option.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        if upsert is not None:
            validate_boolean("upsert", upsert)
        # String hints (index names) and None are stored untouched; any other
        # hint is converted to an index document.
        if hint is not None and not isinstance(hint, str):
            self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
        else:
            self._hint = hint
        self._filter = filter
        self._doc = replacement
        self._upsert = upsert
        self._collation = collation

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        bulkobj.add_replace(
            self._filter,
            self._doc,
            self._upsert,
            collation=validate_collation_or_none(self._collation),
            hint=self._hint,
        )

    def __eq__(self, other: Any) -> bool:
        """Compare every field, including the hint, with another ReplaceOne."""
        if type(other) == type(self):
            return (
                other._filter,
                other._doc,
                other._upsert,
                other._collation,
                other._hint,
            ) == (
                self._filter,
                self._doc,
                self._upsert,
                self._collation,
                # Must be self._hint: comparing other._hint with itself made
                # operations that differ only in their hint compare equal.
                self._hint,
            )
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        return not self == other

    def __repr__(self) -> str:
        return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
            self.__class__.__name__,
            self._filter,
            self._doc,
            self._upsert,
            self._collation,
            self._hint,
        )
|
||||
|
||||
|
||||
class _UpdateOp:
    """Private base class for update operations."""

    __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")

    def __init__(
        self,
        filter: Mapping[str, Any],
        doc: Union[Mapping[str, Any], _Pipeline],
        upsert: bool,
        collation: Optional[_CollationIn],
        array_filters: Optional[list[Mapping[str, Any]]],
        hint: Optional[_IndexKeyHint],
    ):
        """Validate the arguments and store them on the instance.

        ``filter``, ``upsert`` and ``array_filters`` are validated only when
        they are not ``None``. A non-string hint is converted to an index
        document.
        """
        if filter is not None:
            validate_is_mapping("filter", filter)
        if upsert is not None:
            validate_boolean("upsert", upsert)
        if array_filters is not None:
            validate_list("array_filters", array_filters)
        if hint is None or isinstance(hint, str):
            self._hint: Union[str, dict[str, Any], None] = hint
        else:
            self._hint = helpers._index_document(hint)

        self._filter = filter
        self._doc = doc
        self._upsert = upsert
        self._collation = collation
        self._array_filters = array_filters

    def __eq__(self, other: object) -> bool:
        # Comparable with instances of this class or its subclasses only.
        if not isinstance(other, type(self)):
            return NotImplemented
        theirs = (
            other._filter,
            other._doc,
            other._upsert,
            other._collation,
            other._array_filters,
            other._hint,
        )
        mine = (
            self._filter,
            self._doc,
            self._upsert,
            self._collation,
            self._array_filters,
            self._hint,
        )
        return theirs == mine

    def __repr__(self) -> str:
        fields = (
            self._filter,
            self._doc,
            self._upsert,
            self._collation,
            self._array_filters,
            self._hint,
        )
        return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(self.__class__.__name__, *fields)
|
||||
|
||||
|
||||
class UpdateOne(_UpdateOp):
    """Represents an update_one operation."""

    __slots__ = ()

    def __init__(
        self,
        filter: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        upsert: bool = False,
        collation: Optional[_CollationIn] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create an UpdateOne instance.

        For use with :meth:`~pymongo.collection.Collection.bulk_write`.

        :param filter: A query that matches the document to update.
        :param update: The modifications to apply.
        :param upsert: If ``True``, perform an insert if no documents
            match the filter.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param array_filters: A list of filters specifying which
            array elements an update should apply.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.Collection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.2 and above.

        .. versionchanged:: 3.11
           Added the `hint` option.
        .. versionchanged:: 3.9
           Added the ability to accept a pipeline as the `update`.
        .. versionchanged:: 3.6
           Added the `array_filters` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        super().__init__(filter, update, upsert, collation, array_filters, hint)

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        # The third positional argument is False for update_one.
        bulkobj.add_update(
            self._filter,
            self._doc,
            False,
            self._upsert,
            collation=checked_collation,
            array_filters=self._array_filters,
            hint=self._hint,
        )
|
||||
|
||||
|
||||
class UpdateMany(_UpdateOp):
    """Represents an update_many operation."""

    __slots__ = ()

    def __init__(
        self,
        filter: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        upsert: bool = False,
        collation: Optional[_CollationIn] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Optional[_IndexKeyHint] = None,
    ) -> None:
        """Create an UpdateMany instance.

        For use with :meth:`~pymongo.collection.Collection.bulk_write`.

        :param filter: A query that matches the documents to update.
        :param update: The modifications to apply.
        :param upsert: If ``True``, perform an insert if no documents
            match the filter.
        :param collation: An instance of
            :class:`~pymongo.collation.Collation`.
        :param array_filters: A list of filters specifying which
            array elements an update should apply.
        :param hint: An index to use to support the query
            predicate specified either by its string name, or in the same
            format as passed to
            :meth:`~pymongo.collection.Collection.create_index` (e.g.
            ``[('field', ASCENDING)]``). This option is only supported on
            MongoDB 4.2 and above.

        .. versionchanged:: 3.11
           Added the `hint` option.
        .. versionchanged:: 3.9
           Added the ability to accept a pipeline as the `update`.
        .. versionchanged:: 3.6
           Added the `array_filters` option.
        .. versionchanged:: 3.5
           Added the `collation` option.
        """
        super().__init__(filter, update, upsert, collation, array_filters, hint)

    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
        """Add this operation to the _Bulk instance `bulkobj`."""
        checked_collation = validate_collation_or_none(self._collation)
        # The third positional argument is True for update_many.
        bulkobj.add_update(
            self._filter,
            self._doc,
            True,
            self._upsert,
            collation=checked_collation,
            array_filters=self._array_filters,
            hint=self._hint,
        )
|
||||
|
||||
|
||||
class IndexModel:
    """Represents an index to create."""

    __slots__ = ("__document",)

    def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
        """Create an Index instance.

        For use with :meth:`~pymongo.collection.Collection.create_indexes`.

        Takes either a single key or a list containing (key, direction) pairs
        or keys.  If no direction is given, :data:`~pymongo.ASCENDING` will
        be assumed.
        The key(s) must be an instance of :class:`str`, and the direction(s) must
        be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
        :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
        :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).

        Valid options include, but are not limited to:

          - `name`: custom name to use for this index - if none is
            given, a name will be generated.
          - `unique`: if ``True``, creates a uniqueness constraint on the index.
          - `background`: if ``True``, this index should be created in the
            background.
          - `sparse`: if ``True``, omit from the index any documents that lack
            the indexed field.
          - `bucketSize`: for use with geoHaystack indexes.
            Number of documents to group together within a certain proximity
            to a given longitude and latitude.
          - `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
            index.
          - `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
            index.
          - `expireAfterSeconds`: <int> Used to create an expiring (TTL)
            collection. MongoDB will automatically delete documents from
            this collection after <int> seconds. The indexed field must
            be a UTC datetime or the data will not expire.
          - `partialFilterExpression`: A document that specifies a filter for
            a partial index.
          - `collation`: An instance of :class:`~pymongo.collation.Collation`
            that specifies the collation to use.
          - `wildcardProjection`: Allows users to include or exclude specific
            field paths from a `wildcard index`_ using the { "$**" : 1} key
            pattern. Requires MongoDB >= 4.2.
          - `hidden`: if ``True``, this index will be hidden from the query
            planner and will not be evaluated as part of query plan
            selection. Requires MongoDB >= 4.4.

        See the MongoDB documentation for a full list of supported options by
        server version.

        :param keys: a single key or a list containing (key, direction) pairs
            or keys specifying the index to create.
        :param kwargs: any additional index creation
            options (see the above list) should be passed as keyword
            arguments.

        .. versionchanged:: 3.11
           Added the ``hidden`` option.
        .. versionchanged:: 3.2
           Added the ``partialFilterExpression`` option to support partial
           indexes.

        .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
        """
        key_list = _index_list(keys)
        # Generate a name only when the caller did not supply one.
        if kwargs.get("name") is None:
            kwargs["name"] = _gen_index_name(key_list)
        kwargs["key"] = _index_document(key_list)
        # Take the collation out, validate it, and put it back only if set.
        checked_collation = validate_collation_or_none(kwargs.pop("collation", None))
        self.__document = kwargs
        if checked_collation is not None:
            self.__document["collation"] = checked_collation

    @property
    def document(self) -> dict[str, Any]:
        """An index document suitable for passing to the createIndexes
        command.
        """
        return self.__document
|
||||
|
||||
|
||||
class SearchIndexModel:
    """Represents a search index to create."""

    __slots__ = ("__document",)

    def __init__(
        self,
        definition: Mapping[str, Any],
        name: Optional[str] = None,
        type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create a Search Index instance.

        For use with :meth:`~pymongo.collection.Collection.create_search_index` and :meth:`~pymongo.collection.Collection.create_search_indexes`.

        :param definition: The definition for this index.
        :param name: The name for this index, if present.
        :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
        :param kwargs: Keyword arguments supplying any additional options.

        .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster.
        .. versionadded:: 4.5
        .. versionchanged:: 4.7
           Added the type and kwargs arguments.
        """
        # Build the document locally: optional name, then definition, then
        # optional type, then any extra options.
        index_doc: dict[str, Any] = {}
        if name is not None:
            index_doc["name"] = name
        index_doc["definition"] = definition
        if type is not None:
            index_doc["type"] = type
        index_doc.update(kwargs)
        self.__document = index_doc

    @property
    def document(self) -> Mapping[str, Any]:
        """The document for this index."""
        return self.__document
|
||||
__doc__ = original_doc
|
||||
|
||||
2111
pymongo/pool.py
2111
pymongo/pool.py
File diff suppressed because it is too large
Load Diff
@ -17,6 +17,7 @@ context.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket as _socket
|
||||
import ssl as _stdlibssl
|
||||
import sys as _sys
|
||||
@ -364,6 +365,58 @@ class SSLContext:
|
||||
# but not the same as CPython's.
|
||||
self._ctx.set_default_verify_paths()
|
||||
|
||||
async def a_wrap_socket(
    self,
    sock: _socket.socket,
    server_side: bool = False,
    do_handshake_on_connect: bool = True,
    suppress_ragged_eofs: bool = True,
    server_hostname: Optional[str] = None,
    session: Optional[_SSL.Session] = None,
) -> _sslConn:
    """Wrap an existing Python socket connection and return a TLS socket
    object.

    Coroutine variant of ``wrap_socket``: the OCSP request and the TLS
    handshake are handed to the running loop's default executor via
    ``run_in_executor`` and awaited.

    :param sock: The already-connected socket to wrap.
    :param server_side: If exactly ``True``, put the connection in accept
        (server) state; otherwise configure it as a client.
    :param do_handshake_on_connect: Perform the handshake (and, when
        ``check_hostname`` is set and a hostname was given, the hostname
        check) before returning.
    :param suppress_ragged_eofs: Passed through to ``_sslConn``.
    :param server_hostname: Host used for SNI and for certificate
        verification.
    :param session: Optional TLS session to set on the connection.
    :return: The wrapped ``_sslConn``.
    :raises _CertificateError: If certificate hostname/IP verification
        fails.
    """
    ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
    loop = asyncio.get_running_loop()
    if session:
        ssl_conn.set_session(session)
    if server_side is True:
        ssl_conn.set_accept_state()
    else:
        # SNI
        if server_hostname and not _is_ip_address(server_hostname):
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
        if self.verify_mode != _stdlibssl.CERT_NONE:
            # Request a stapled OCSP response.
            await loop.run_in_executor(None, ssl_conn.request_ocsp)
        ssl_conn.set_connect_state()
    # If this wasn't true the caller of wrap_socket would call
    # do_handshake()
    if do_handshake_on_connect:
        # XXX: If we do hostname checking in a callback we can get rid
        # of this call to do_handshake() since the handshake
        # will happen automatically later.
        await loop.run_in_executor(None, ssl_conn.do_handshake)
        # XXX: Do this in a callback registered with
        # SSLContext.set_info_callback? See Twisted for an example.
        if self.check_hostname and server_hostname is not None:
            from service_identity import pyopenssl

            try:
                if _is_ip_address(server_hostname):
                    pyopenssl.verify_ip_address(ssl_conn, server_hostname)
                else:
                    pyopenssl.verify_hostname(ssl_conn, server_hostname)
            # service_identity exports these as CertificateError and
            # VerificationError; the previous "SI"-prefixed names do not
            # exist on the package, so a failed verification raised
            # AttributeError here instead of _CertificateError.
            except (  # type:ignore[misc]
                service_identity.CertificateError,
                service_identity.VerificationError,
            ) as exc:
                raise _CertificateError(str(exc)) from None
    return ssl_conn
|
||||
|
||||
def wrap_socket(
|
||||
self,
|
||||
sock: _socket.socket,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Copyright 2012-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License",
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
@ -12,611 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for choosing which member of a replica set to read from."""
|
||||
|
||||
"""Re-import of synchronous ReadPreferences API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import abc
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
|
||||
from pymongo.synchronous.read_preferences import * # noqa: F403
|
||||
from pymongo.synchronous.read_preferences import __doc__ as original_doc
|
||||
|
||||
from pymongo import max_staleness_selectors
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_selectors import (
|
||||
member_with_tags_server_selector,
|
||||
secondary_with_tags_server_selector,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.server_selectors import Selection
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
|
||||
# Integer identifiers for the read preference modes. Each value is also the
# position of the matching mongos/URI name in _MONGOS_MODES below.
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
_SECONDARY_PREFERRED = 3
_NEAREST = 4


# mongos/URI spelling of each mode, ordered to line up with the integer
# constants above.
_MONGOS_MODES = (
    "primary",
    "primaryPreferred",
    "secondary",
    "secondaryPreferred",
    "nearest",
)

# Type aliases used in the signatures in this module.
_Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
    """Check the ``tag_sets`` option given to a MongoClient.

    ``None`` is passed straight through. Anything else must be a non-empty
    list or tuple whose items are all mappings; a new list with the same
    items is returned.

    :raises TypeError: If ``tag_sets`` is not a list/tuple, or one of its
        items is not a mapping.
    :raises ValueError: If ``tag_sets`` is an empty sequence.
    """
    if tag_sets is None:
        return tag_sets

    if not isinstance(tag_sets, (list, tuple)):
        raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
    if len(tag_sets) == 0:
        raise ValueError(
            f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
        )

    # Find the first entry (if any) that is not a mapping.
    no_offender = object()
    offender = next(
        (entry for entry in tag_sets if not isinstance(entry, abc.Mapping)),
        no_offender,
    )
    if offender is not no_offender:
        raise TypeError(
            f"Tag set {offender!r} invalid, must be an instance of dict, "
            "bson.son.SON or other type that inherits from "
            "collection.Mapping"
        )

    return list(tag_sets)
|
||||
|
||||
|
||||
def _invalid_max_staleness_msg(max_staleness: Any) -> str:
    """Build the error text for a rejected ``maxStalenessSeconds`` value.

    :param max_staleness: The offending value, of any type.
    :return: The message, ending with ``str(max_staleness)``.
    """
    # An f-string is used rather than ``"%s" % value`` because %-formatting
    # unpacks tuples: a 1-tuple would lose its parentheses and any other
    # tuple would raise a formatting TypeError instead of producing the
    # message.
    return f"maxStalenessSeconds must be a positive integer, not {max_staleness!s}"
|
||||
|
||||
|
||||
# Some duplication with common.py to avoid import cycle.
|
||||
def _validate_max_staleness(max_staleness: Any) -> int:
    """Check a ``max_staleness`` value and return it.

    Anything equal to ``-1`` yields ``-1`` (no maximum). Every other value
    must be an ``int`` greater than zero.

    :raises TypeError: If the value is not an ``int``.
    :raises ValueError: If the value is an ``int`` that is zero or negative.
    """
    if max_staleness != -1:
        if not isinstance(max_staleness, int):
            raise TypeError(_invalid_max_staleness_msg(max_staleness))
        if max_staleness <= 0:
            raise ValueError(_invalid_max_staleness_msg(max_staleness))
        return max_staleness

    return -1
|
||||
|
||||
|
||||
def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
    """Check the ``hedge`` option and return it unchanged.

    ``None`` and ``dict`` instances are accepted.

    :raises TypeError: If ``hedge`` is neither ``None`` nor a ``dict``.
    """
    if hedge is not None and not isinstance(hedge, dict):
        raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
    return hedge
|
||||
|
||||
|
||||
class _ServerMode:
    """Common base for every read preference class."""

    __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")

    def __init__(
        self,
        mode: int,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        # The mongos name is looked up first, so an out-of-range mode fails
        # before any of the option validators run.
        self.__mongos_mode = _MONGOS_MODES[mode]
        self.__mode = mode
        self.__tag_sets = _validate_tag_sets(tag_sets)
        self.__max_staleness = _validate_max_staleness(max_staleness)
        self.__hedge = _validate_hedge(hedge)

    @property
    def name(self) -> str:
        """The name of this read preference (its class name)."""
        return self.__class__.__name__

    @property
    def mongos_mode(self) -> str:
        """The mongos mode of this read preference."""
        return self.__mongos_mode

    @property
    def document(self) -> dict[str, Any]:
        """Read preference as a document.

        ``mode`` is always present; ``tags``, ``maxStalenessSeconds`` and
        ``hedge`` are included only when they differ from their "unset"
        values.
        """
        spec: dict[str, Any] = {"mode": self.__mongos_mode}
        if self.__tag_sets not in (None, [{}]):
            spec["tags"] = self.__tag_sets
        if self.__max_staleness != -1:
            spec["maxStalenessSeconds"] = self.__max_staleness
        if self.__hedge not in (None, {}):
            spec["hedge"] = self.__hedge
        return spec

    @property
    def mode(self) -> int:
        """The mode of this read preference instance."""
        return self.__mode

    @property
    def tag_sets(self) -> _TagSets:
        """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
        read only from members whose ``dc`` tag has the value ``"ny"``.
        To specify a priority-order for tag sets, provide a list of
        tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
        set, ``{}``, means "read from any member that matches the mode,
        ignoring tags." MongoClient tries each set of tags in turn
        until it finds a set of tags with at least one matching member.
        For example, to only send a query to an analytic node::

           Nearest(tag_sets=[{"node":"analytics"}])

        Or using :class:`SecondaryPreferred`::

           SecondaryPreferred(tag_sets=[{"node":"analytics"}])

        When no tag sets are stored, ``[{}]`` is returned.

        .. seealso:: `Data-Center Awareness
            <https://www.mongodb.com/docs/manual/data-center-awareness/>`_
        """
        stored = self.__tag_sets
        if stored:
            return list(stored)
        return [{}]

    @property
    def max_staleness(self) -> int:
        """The maximum estimated length of time (in seconds) a replica set
        secondary can fall behind the primary in replication before it will
        no longer be selected for operations, or -1 for no maximum.
        """
        return self.__max_staleness

    @property
    def hedge(self) -> Optional[_Hedge]:
        """The read preference ``hedge`` parameter.

        A dictionary that configures how the server will perform hedged reads.
        It consists of the following keys:

        - ``enabled``: Enables or disables hedged reads in sharded clusters.

        Hedged reads are automatically enabled in MongoDB 4.4+ when using a
        ``nearest`` read preference. To explicitly enable hedged reads, set
        the ``enabled`` key to ``True``::

            >>> Nearest(hedge={'enabled': True})

        To explicitly disable hedged reads, set the ``enabled`` key to
        ``False``::

            >>> Nearest(hedge={'enabled': False})

        .. versionadded:: 3.11
        """
        return self.__hedge

    @property
    def min_wire_version(self) -> int:
        """The wire protocol version the server must support.

        Some read preferences impose version requirements on all servers (e.g.
        maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).

        All servers' maxWireVersion must be at least this read preference's
        `min_wire_version`, or the driver raises
        :exc:`~pymongo.errors.ConfigurationError`.
        """
        if self.__max_staleness == -1:
            return 0
        return 5

    def __repr__(self) -> str:
        return (
            f"{self.name}(tag_sets={self.__tag_sets!r}, "
            f"max_staleness={self.__max_staleness!r}, hedge={self.__hedge!r})"
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, _ServerMode):
            return NotImplemented
        return (
            self.mode == other.mode
            and self.tag_sets == other.tag_sets
            and self.max_staleness == other.max_staleness
            and self.hedge == other.hedge
        )

    def __ne__(self, other: Any) -> bool:
        return not (self == other)

    def __getstate__(self) -> dict[str, Any]:
        """Return value of object for pickling.

        Spelled out by hand because the class uses ``__slots__``.
        """
        state: dict[str, Any] = {}
        state["mode"] = self.__mode
        state["tag_sets"] = self.__tag_sets
        state["max_staleness"] = self.__max_staleness
        state["hedge"] = self.__hedge
        return state

    def __setstate__(self, value: Mapping[str, Any]) -> None:
        """Restore from pickling, re-validating every stored option."""
        mode = value["mode"]
        self.__mode = mode
        self.__mongos_mode = _MONGOS_MODES[mode]
        self.__tag_sets = _validate_tag_sets(value["tag_sets"])
        self.__max_staleness = _validate_max_staleness(value["max_staleness"])
        self.__hedge = _validate_hedge(value["hedge"])

    def __call__(self, selection: Selection) -> Selection:
        # The base class applies no filtering.
        return selection
|
||||
|
||||
|
||||
class Primary(_ServerMode):
    """Primary read preference.

    * Directly connected to a single mongod: queries are allowed if the
      server is standalone or a replica set primary.
    * Connected to a mongos: queries are sent to the primary of a shard.
    * Connected to a replica set: queries are sent to the primary of the
      replica set.
    """

    __slots__ = ()

    def __init__(self) -> None:
        super().__init__(_PRIMARY)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to a Selection."""
        return selection.primary_selection

    def __repr__(self) -> str:
        return "Primary()"

    def __eq__(self, other: Any) -> bool:
        # Equal to any read preference whose mode is the primary mode.
        if not isinstance(other, _ServerMode):
            return NotImplemented
        return other.mode == _PRIMARY
|
||||
|
||||
|
||||
class PrimaryPreferred(_ServerMode):
    """PrimaryPreferred read preference.

    * Directly connected to a single mongod: queries are allowed to
      standalone servers, to a replica set primary, or to replica set
      secondaries.
    * Connected to a mongos: queries are sent to the primary of a shard if
      available, otherwise a shard secondary.
    * Connected to a replica set: queries are sent to the primary if
      available, otherwise a secondary.

    .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
      created reads will be routed to an available secondary until the
      primary of the replica set is discovered.

    :param tag_sets: The :attr:`~tag_sets` to use if the primary is not
        available.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` to use if the primary is not available.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        if selection.primary:
            return selection.primary_selection

        # No primary: fall back to tagged secondaries within the staleness bound.
        tags = self.tag_sets
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        return secondary_with_tags_server_selector(tags, fresh_enough)
|
||||
|
||||
|
||||
class Secondary(_ServerMode):
    """Secondary read preference.

    * Directly connected to a single mongod: queries are allowed to
      standalone servers, to a replica set primary, or to replica set
      secondaries.
    * Connected to a mongos: queries are distributed among shard
      secondaries. An error is raised if no secondaries are available.
    * Connected to a replica set: queries are distributed among
      secondaries. An error is raised if no secondaries are available.

    :param tag_sets: The :attr:`~tag_sets` for this read preference.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` for this read preference.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_SECONDARY, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        tags = self.tag_sets
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        return secondary_with_tags_server_selector(tags, fresh_enough)
|
||||
|
||||
|
||||
class SecondaryPreferred(_ServerMode):
    """SecondaryPreferred read preference.

    * Directly connected to a single mongod: queries are allowed to
      standalone servers, to a replica set primary, or to replica set
      secondaries.
    * Connected to a mongos: queries are distributed among shard
      secondaries, or the shard primary if no secondary is available.
    * Connected to a replica set: queries are distributed among
      secondaries, or the primary if no secondary is available.

    .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
      created reads will be routed to the primary of the replica set until
      an available secondary is discovered.

    :param tag_sets: The :attr:`~tag_sets` for this read preference.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` for this read preference.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        tags = self.tag_sets
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        candidates = secondary_with_tags_server_selector(tags, fresh_enough)

        # A falsy (empty) secondary selection falls back to the primary.
        if not candidates:
            return selection.primary_selection
        return candidates
|
||||
|
||||
|
||||
class Nearest(_ServerMode):
    """Nearest read preference.

    * Directly connected to a single mongod: queries are allowed to
      standalone servers, to a replica set primary, or to replica set
      secondaries.
    * Connected to a mongos: queries are distributed among all members of
      a shard.
    * Connected to a replica set: queries are distributed among all
      members.

    :param tag_sets: The :attr:`~tag_sets` for this read preference.
    :param max_staleness: (integer, in seconds) The maximum estimated
        length of time a replica set secondary can fall behind the primary in
        replication before it will no longer be selected for operations.
        Default -1, meaning no maximum. If it is set, it must be at least
        90 seconds.
    :param hedge: The :attr:`~hedge` for this read preference.

    .. versionchanged:: 3.11
       Added ``hedge`` parameter.
    """

    __slots__ = ()

    def __init__(
        self,
        tag_sets: Optional[_TagSets] = None,
        max_staleness: int = -1,
        hedge: Optional[_Hedge] = None,
    ) -> None:
        super().__init__(_NEAREST, tag_sets, max_staleness, hedge)

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to Selection."""
        tags = self.tag_sets
        fresh_enough = max_staleness_selectors.select(self.max_staleness, selection)
        return member_with_tags_server_selector(tags, fresh_enough)
|
||||
|
||||
|
||||
class _AggWritePref:
    """Agg $out/$merge write preference.

    * If there are readable servers and there is any pre-5.0 server, use
      primary read preference.
    * Otherwise use `pref` read preference.

    :param pref: The read preference to use on MongoDB 5.0+.
    """

    __slots__ = ("pref", "effective_pref")

    def __init__(self, pref: _ServerMode):
        self.pref = pref
        # Starts out as primary; selection_hook() updates it.
        self.effective_pref: _ServerMode = ReadPreference.PRIMARY

    def selection_hook(self, topology_description: TopologyDescription) -> None:
        """Recompute ``effective_pref`` from the given topology description.

        Primary is chosen when the topology has a server readable with
        PRIMARY_PREFERRED and its common wire version is truthy and below
        13; otherwise the wrapped ``pref`` is used.
        """
        common_wv = topology_description.common_wire_version
        if (
            topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED)
            and common_wv
            and common_wv < 13
        ):
            self.effective_pref = ReadPreference.PRIMARY
        else:
            self.effective_pref = self.pref

    def __call__(self, selection: Selection) -> Selection:
        """Apply this read preference to a Selection."""
        return self.effective_pref(selection)

    def __repr__(self) -> str:
        return f"_AggWritePref(pref={self.pref!r})"

    # Proxy other calls to the effective_pref so that _AggWritePref can be
    # used in place of an actual read preference.
    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup failed. If the missing
        # name is one of our own slots, the instance is not initialised yet
        # (e.g. copy/pickle probing a freshly created object for
        # __setstate__). Delegating would read self.effective_pref again
        # and recurse until RecursionError, so fail cleanly instead.
        if name in self.__slots__:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.effective_pref, name)
|
||||
|
||||
|
||||
# Read preference classes, indexed by their integer mode
# (make_read_preference looks a class up with _ALL_READ_PREFERENCES[mode]).
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest)
|
||||
|
||||
|
||||
def make_read_preference(
    mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1
) -> _ServerMode:
    """Build the read preference object for an integer ``mode``.

    :param mode: Index into ``_ALL_READ_PREFERENCES``.
    :param tag_sets: Tag sets for the preference; for the primary mode this
        must be ``None`` or ``[{}]``.
    :param max_staleness: Max staleness in seconds; for the primary mode
        this must be ``-1``.
    :raises ConfigurationError: If the primary mode is combined with tags
        or with maxStalenessSeconds.
    """
    if mode != _PRIMARY:
        return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness)  # type: ignore

    # Primary takes no options, so reject any that were supplied.
    if tag_sets not in (None, [{}]):
        raise ConfigurationError("Read preference primary cannot be combined with tags")
    if max_staleness != -1:
        raise ConfigurationError(
            "Read preference primary cannot be combined with maxStalenessSeconds"
        )
    return Primary()
|
||||
|
||||
|
||||
# Upper-case mode names in mode-integer order; they match the attribute names
# defined on ReadPreference.
# NOTE(review): no use of this tuple is visible in this part of the file --
# presumably consumed elsewhere; confirm before changing.
_MODES = (
    "PRIMARY",
    "PRIMARY_PREFERRED",
    "SECONDARY",
    "SECONDARY_PREFERRED",
    "NEAREST",
)
|
||||
|
||||
|
||||
class ReadPreference:
    """An enum that defines some commonly used read preference modes.

    Apps can also create a custom read preference, for example::

       Nearest(tag_sets=[{"node":"analytics"}])

    See :doc:`/examples/high_availability` for code examples.

    A read preference is used in three cases:

    :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:

    - ``PRIMARY``: Queries are allowed if the server is standalone or a replica
      set primary.
    - All other modes allow queries to standalone servers, to a replica set
      primary, or to replica set secondaries.

    :class:`~pymongo.mongo_client.MongoClient` initialized with the
    ``replicaSet`` option:

    - ``PRIMARY``: Read from the primary. This is the default, and provides the
      strongest consistency. If no primary is available, raise
      :class:`~pymongo.errors.AutoReconnect`.

    - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
      none, read from a secondary.

    - ``SECONDARY``: Read from a secondary. If no secondary is available,
      raise :class:`~pymongo.errors.AutoReconnect`.

    - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
      from the primary.

    - ``NEAREST``: Read from any member.

    :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
    sharded cluster of replica sets:

    - ``PRIMARY``: Read from the primary of the shard, or raise
      :class:`~pymongo.errors.OperationFailure` if there is none.
      This is the default.

    - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
      none, read from a secondary of the shard.

    - ``SECONDARY``: Read from a secondary of the shard, or raise
      :class:`~pymongo.errors.OperationFailure` if there is none.

    - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
      otherwise from the shard primary.

    - ``NEAREST``: Read from any shard member.
    """

    # One instance of each mode, constructed with default options when the
    # class body executes and stored as class attributes.
    PRIMARY = Primary()
    PRIMARY_PREFERRED = PrimaryPreferred()
    SECONDARY = Secondary()
    SECONDARY_PREFERRED = SecondaryPreferred()
    NEAREST = Nearest()
|
||||
|
||||
|
||||
def read_pref_mode_from_name(name: str) -> int:
    """Get the read preference mode from mongos/uri name.

    :param name: One of the strings in ``_MONGOS_MODES`` (exact match).
    :return: The position of ``name`` in ``_MONGOS_MODES``.
    :raises ValueError: If ``name`` is not in ``_MONGOS_MODES``.
    """
    return _MONGOS_MODES.index(name)
|
||||
|
||||
|
||||
class MovingAverage:
    """Keeps an exponentially-weighted moving average of samples."""

    average: Optional[float]

    def __init__(self) -> None:
        self.average = None

    def add_sample(self, sample: float) -> None:
        """Fold ``sample`` into the average; negative samples are ignored."""
        if sample < 0:
            # Likely system time change while waiting for hello response
            # and not using time.monotonic. Ignore it, the next one will
            # probably be valid.
            return

        previous = self.average
        if previous is not None:
            # The Server Selection Spec requires an exponentially weighted
            # average with alpha = 0.2.
            self.average = 0.8 * previous + 0.2 * sample
        else:
            # The first sample seeds the average.
            self.average = sample

    def get(self) -> Optional[float]:
        """Get the calculated average, or None if no samples yet."""
        return self.average

    def reset(self) -> None:
        """Forget all samples."""
        self.average = None
|
||||
__doc__ = original_doc
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,288 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Represent one server the driver is connected to."""
|
||||
"""Re-import of synchronous ServerDescription API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Mapping, Optional
|
||||
from pymongo.synchronous.server_description import * # noqa: F403
|
||||
from pymongo.synchronous.server_description import __doc__ as original_doc
|
||||
|
||||
from bson import EPOCH_NAIVE
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
|
||||
class ServerDescription:
|
||||
"""Immutable representation of one server.
|
||||
|
||||
:param address: A (host, port) pair
|
||||
:param hello: Optional Hello instance
|
||||
:param round_trip_time: Optional float
|
||||
:param error: Optional, the last error attempting to connect to the server
|
||||
:param round_trip_time: Optional float, the min latency from the most recent samples
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_address",
|
||||
"_server_type",
|
||||
"_all_hosts",
|
||||
"_tags",
|
||||
"_replica_set_name",
|
||||
"_primary",
|
||||
"_max_bson_size",
|
||||
"_max_message_size",
|
||||
"_max_write_batch_size",
|
||||
"_min_wire_version",
|
||||
"_max_wire_version",
|
||||
"_round_trip_time",
|
||||
"_min_round_trip_time",
|
||||
"_me",
|
||||
"_is_writable",
|
||||
"_is_readable",
|
||||
"_ls_timeout_minutes",
|
||||
"_error",
|
||||
"_set_version",
|
||||
"_election_id",
|
||||
"_cluster_time",
|
||||
"_last_write_date",
|
||||
"_last_update_time",
|
||||
"_topology_version",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: _Address,
|
||||
hello: Optional[Hello] = None,
|
||||
round_trip_time: Optional[float] = None,
|
||||
error: Optional[Exception] = None,
|
||||
min_round_trip_time: float = 0.0,
|
||||
) -> None:
|
||||
self._address = address
|
||||
if not hello:
|
||||
hello = Hello({})
|
||||
|
||||
self._server_type = hello.server_type
|
||||
self._all_hosts = hello.all_hosts
|
||||
self._tags = hello.tags
|
||||
self._replica_set_name = hello.replica_set_name
|
||||
self._primary = hello.primary
|
||||
self._max_bson_size = hello.max_bson_size
|
||||
self._max_message_size = hello.max_message_size
|
||||
self._max_write_batch_size = hello.max_write_batch_size
|
||||
self._min_wire_version = hello.min_wire_version
|
||||
self._max_wire_version = hello.max_wire_version
|
||||
self._set_version = hello.set_version
|
||||
self._election_id = hello.election_id
|
||||
self._cluster_time = hello.cluster_time
|
||||
self._is_writable = hello.is_writable
|
||||
self._is_readable = hello.is_readable
|
||||
self._ls_timeout_minutes = hello.logical_session_timeout_minutes
|
||||
self._round_trip_time = round_trip_time
|
||||
self._min_round_trip_time = min_round_trip_time
|
||||
self._me = hello.me
|
||||
self._last_update_time = time.monotonic()
|
||||
self._error = error
|
||||
self._topology_version = hello.topology_version
|
||||
if error:
|
||||
details = getattr(error, "details", None)
|
||||
if isinstance(details, dict):
|
||||
self._topology_version = details.get("topologyVersion")
|
||||
|
||||
self._last_write_date: Optional[float]
|
||||
if hello.last_write_date:
|
||||
# Convert from datetime to seconds.
|
||||
delta = hello.last_write_date - EPOCH_NAIVE
|
||||
self._last_write_date = delta.total_seconds()
|
||||
else:
|
||||
self._last_write_date = None
|
||||
|
||||
@property
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) of this server."""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def server_type(self) -> int:
|
||||
"""The type of this server."""
|
||||
return self._server_type
|
||||
|
||||
@property
|
||||
def server_type_name(self) -> str:
|
||||
"""The server type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return SERVER_TYPE._fields[self._server_type]
|
||||
|
||||
@property
|
||||
def all_hosts(self) -> set[tuple[str, int]]:
|
||||
"""List of hosts, passives, and arbiters known to this server."""
|
||||
return self._all_hosts
|
||||
|
||||
@property
|
||||
def tags(self) -> Mapping[str, Any]:
|
||||
return self._tags
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def primary(self) -> Optional[tuple[str, int]]:
|
||||
"""This server's opinion about who the primary is, or None."""
|
||||
return self._primary
|
||||
|
||||
@property
|
||||
def max_bson_size(self) -> int:
|
||||
return self._max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self) -> int:
|
||||
return self._max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self) -> int:
|
||||
return self._max_write_batch_size
|
||||
|
||||
@property
|
||||
def min_wire_version(self) -> int:
|
||||
return self._min_wire_version
|
||||
|
||||
@property
|
||||
def max_wire_version(self) -> int:
|
||||
return self._max_wire_version
|
||||
|
||||
@property
|
||||
def set_version(self) -> Optional[int]:
|
||||
return self._set_version
|
||||
|
||||
@property
|
||||
def election_id(self) -> Optional[ObjectId]:
|
||||
return self._election_id
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
|
||||
warnings.warn(
|
||||
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._set_version, self._election_id
|
||||
|
||||
@property
|
||||
def me(self) -> Optional[tuple[str, int]]:
|
||||
return self._me
|
||||
|
||||
@property
def logical_session_timeout_minutes(self) -> Optional[int]:
    """Return the ``_ls_timeout_minutes`` value, which may be None."""
    return self._ls_timeout_minutes
|
||||
|
||||
@property
def last_write_date(self) -> Optional[float]:
    """Return the ``_last_write_date`` value, which may be None."""
    return self._last_write_date
|
||||
|
||||
@property
def last_update_time(self) -> float:
    """Return the ``_last_update_time`` value held by this description."""
    return self._last_update_time
|
||||
|
||||
@property
def round_trip_time(self) -> Optional[float]:
    """The current average latency or None."""
    # This override is for unittesting only!
    overrides = self._host_to_round_trip_time
    if self._address not in overrides:
        return self._round_trip_time
    return overrides[self._address]
|
||||
|
||||
@property
def min_round_trip_time(self) -> float:
    """Return the minimum latency observed over the most recent samples."""
    return self._min_round_trip_time
|
||||
|
||||
@property
def error(self) -> Optional[Exception]:
    """Return the last error raised while connecting to the server, or None."""
    return self._error
|
||||
|
||||
@property
def is_writable(self) -> bool:
    """Return the ``_is_writable`` flag held by this description."""
    return self._is_writable
|
||||
|
||||
@property
def is_readable(self) -> bool:
    """Return the ``_is_readable`` flag held by this description."""
    return self._is_readable
|
||||
|
||||
@property
def mongos(self) -> bool:
    """Return True when this server's type is ``SERVER_TYPE.Mongos``."""
    return self._server_type == SERVER_TYPE.Mongos
|
||||
|
||||
@property
def is_server_type_known(self) -> bool:
    """Return True unless this server's type is ``SERVER_TYPE.Unknown``."""
    return self.server_type != SERVER_TYPE.Unknown
|
||||
|
||||
@property
def retryable_writes_supported(self) -> bool:
    """Checks if this server supports retryable writes.

    True for the Mongos, RSPrimary, and LoadBalancer server types.
    """
    supported_types = (
        SERVER_TYPE.Mongos,
        SERVER_TYPE.RSPrimary,
        SERVER_TYPE.LoadBalancer,
    )
    return self._server_type in supported_types
|
||||
|
||||
@property
def retryable_reads_supported(self) -> bool:
    """Checks if this server supports retryable reads.

    True when the server's maximum wire version is 6 or higher.
    """
    return self._max_wire_version >= 6
|
||||
|
||||
@property
def topology_version(self) -> Optional[Mapping[str, Any]]:
    """Return the ``_topology_version`` mapping, which may be None."""
    return self._topology_version
|
||||
|
||||
def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
    """Return a new description for this address built from ``error`` alone.

    The new description is constructed with only the address and the error,
    then this description's topology version is copied onto it.
    """
    replacement = ServerDescription(self.address, error=error)
    replacement._topology_version = self.topology_version
    return replacement
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
    """Compare two server descriptions field by field.

    Returns NotImplemented for anything that is not a ServerDescription so
    Python can fall back to the other operand's comparison.
    """
    if not isinstance(other, ServerDescription):
        return NotImplemented

    return (
        (self._address == other.address)
        and (self._server_type == other.server_type)
        and (self._min_wire_version == other.min_wire_version)
        and (self._max_wire_version == other.max_wire_version)
        and (self._me == other.me)
        and (self._all_hosts == other.all_hosts)
        and (self._tags == other.tags)
        and (self._replica_set_name == other.replica_set_name)
        and (self._set_version == other.set_version)
        and (self._election_id == other.election_id)
        and (self._primary == other.primary)
        and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
        and (self._error == other.error)
    )
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
    """Return the logical negation of ``self == other``."""
    is_equal = self == other
    return not is_equal
|
||||
|
||||
def __repr__(self) -> str:
    """Return ``<Class address server_type: ..., rtt: ...>``, plus the error if one is set."""
    errmsg = f", error={self.error!r}" if self.error else ""
    return (
        f"<{self.__class__.__name__} {self.address} "
        f"server_type: {self.server_type_name}, rtt: {self.round_trip_time}{errmsg}>"
    )
|
||||
|
||||
# For unittesting only. Use under no circumstances!
|
||||
_host_to_round_trip_time: dict = {}
|
||||
__doc__ = original_doc
|
||||
|
||||
0
pymongo/synchronous/__init__.py
Normal file
0
pymongo/synchronous/__init__.py
Normal file
@ -18,20 +18,22 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Mapping, MutableMapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo import common
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_preferences import ReadPreference, _AggWritePref
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.collation import validate_collation_or_none
|
||||
from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.database import Database
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server import Server
|
||||
from pymongo.typings import _DocumentType, _Pipeline
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.read_preferences import _ServerMode
|
||||
from pymongo.synchronous.server import Server
|
||||
from pymongo.synchronous.typings import _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class _AggregationCommand:
|
||||
658
pymongo/synchronous/auth.py
Normal file
658
pymongo/synchronous/auth.py
Normal file
@ -0,0 +1,658 @@
|
||||
# Copyright 2013-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
from pymongo.synchronous.auth_aws import _authenticate_aws
|
||||
from pymongo.synchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
_OIDCAzureCallback,
|
||||
_OIDCGCPCallback,
|
||||
_OIDCProperties,
|
||||
_OIDCTestCallback,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.hello import Hello
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
try:
|
||||
import winkerberos as kerberos # type:ignore[import]
|
||||
|
||||
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
|
||||
_USE_PRINCIPAL = True
|
||||
except ImportError:
|
||||
try:
|
||||
import kerberos # type:ignore[import]
|
||||
except ImportError:
|
||||
HAVE_KERBEROS = False
|
||||
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
MECHANISMS = frozenset(
|
||||
[
|
||||
"GSSAPI",
|
||||
"MONGODB-CR",
|
||||
"MONGODB-OIDC",
|
||||
"MONGODB-X509",
|
||||
"MONGODB-AWS",
|
||||
"PLAIN",
|
||||
"SCRAM-SHA-1",
|
||||
"SCRAM-SHA-256",
|
||||
"DEFAULT",
|
||||
]
|
||||
)
|
||||
"""The authentication mechanisms supported by PyMongo."""
|
||||
|
||||
|
||||
class _Cache:
    """Mutable holder for cached authentication data.

    Every instance compares equal to every other instance and shares one hash
    value, so the contents of ``data`` never affect equality or hashing.
    """

    __slots__ = ("data",)

    _hash_val = hash("_Cache")

    def __init__(self) -> None:
        self.data = None

    def __eq__(self, other: object) -> bool:
        # Two instances must always compare equal.
        return True if isinstance(other, _Cache) else NotImplemented

    def __ne__(self, other: object) -> bool:
        return False if isinstance(other, _Cache) else NotImplemented

    def __hash__(self) -> int:
        return self._hash_val
|
||||
|
||||
|
||||
MongoCredential = namedtuple(
|
||||
"MongoCredential",
|
||||
["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
|
||||
)
|
||||
"""A hashable namedtuple of values used for authentication."""
|
||||
|
||||
|
||||
GSSAPIProperties = namedtuple(
|
||||
"GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
|
||||
)
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
def _build_credentials_tuple(
    mech: str,
    source: Optional[str],
    user: str,
    passwd: str,
    extra: Mapping[str, Any],
    database: Optional[str],
) -> MongoCredential:
    """Build and return a mechanism specific credentials tuple.

    :param mech: Authentication mechanism name (e.g. "GSSAPI", "PLAIN").
    :param source: Explicit authentication source, or None.
    :param user: Username; compared against None below, so callers may pass None.
    :param passwd: Password; compared against None below, so callers may pass None.
    :param extra: Options mapping; only the "authmechanismproperties" key is read.
    :param database: Fallback source for PLAIN and the default branch.
    :return: A MongoCredential whose source, properties and cache depend on ``mech``.
    :raises ConfigurationError: On missing/unsupported username or password, or
        invalid MONGODB-AWS / MONGODB-OIDC settings.
    :raises ValueError: When GSSAPI or MONGODB-X509 is given a source other
        than "$external".
    """
    # Only X509, AWS and OIDC may omit the username.
    if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
        raise ConfigurationError(f"{mech} requires a username.")
    if mech == "GSSAPI":
        if source is not None and source != "$external":
            raise ValueError("authentication source must be $external or None for GSSAPI")
        properties = extra.get("authmechanismproperties", {})
        service_name = properties.get("SERVICE_NAME", "mongodb")
        canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
        service_realm = properties.get("SERVICE_REALM")
        props = GSSAPIProperties(
            service_name=service_name,
            canonicalize_host_name=canonicalize,
            service_realm=service_realm,
        )
        # Source is always $external.
        return MongoCredential(mech, "$external", user, passwd, props, None)
    elif mech == "MONGODB-X509":
        if passwd is not None:
            raise ConfigurationError("Passwords are not supported by MONGODB-X509")
        if source is not None and source != "$external":
            raise ValueError("authentication source must be $external or None for MONGODB-X509")
        # Source is always $external, user can be None.
        return MongoCredential(mech, "$external", user, None, None, None)
    elif mech == "MONGODB-AWS":
        if user is not None and passwd is None:
            raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
        if source is not None and source != "$external":
            raise ConfigurationError(
                "authentication source must be $external or None for MONGODB-AWS"
            )

        properties = extra.get("authmechanismproperties", {})
        aws_session_token = properties.get("AWS_SESSION_TOKEN")
        aws_props = _AWSProperties(aws_session_token=aws_session_token)
        # user can be None for temporary link-local EC2 credentials.
        return MongoCredential(mech, "$external", user, passwd, aws_props, None)
    elif mech == "MONGODB-OIDC":
        properties = extra.get("authmechanismproperties", {})
        callback = properties.get("OIDC_CALLBACK")
        human_callback = properties.get("OIDC_HUMAN_CALLBACK")
        environ = properties.get("ENVIRONMENT")
        token_resource = properties.get("TOKEN_RESOURCE", "")
        default_allowed = [
            "*.mongodb.net",
            "*.mongodb-dev.net",
            "*.mongodb-qa.net",
            "*.mongodbgov.net",
            "localhost",
            "127.0.0.1",
            "::1",
        ]
        allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
        # Default error text; reused by several of the checks below.
        msg = (
            "authentication with MONGODB-OIDC requires providing either a callback or a environment"
        )
        if passwd is not None:
            msg = "password is not supported by MONGODB-OIDC"
            raise ConfigurationError(msg)
        if callback or human_callback:
            # A user-supplied callback and ENVIRONMENT are mutually exclusive.
            if environ is not None:
                raise ConfigurationError(msg)
            if callback and human_callback:
                msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
                raise ConfigurationError(msg)
        elif environ is not None:
            # No callback given: pick the built-in callback for the environment.
            if environ == "test":
                if user is not None:
                    msg = "test environment for MONGODB-OIDC does not support username"
                    raise ConfigurationError(msg)
                callback = _OIDCTestCallback()
            elif environ == "azure":
                passwd = None
                if not token_resource:
                    raise ConfigurationError(
                        "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
                    )
                callback = _OIDCAzureCallback(token_resource)
            elif environ == "gcp":
                passwd = None
                if not token_resource:
                    raise ConfigurationError(
                        "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
                    )
                callback = _OIDCGCPCallback(token_resource)
            else:
                raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
        else:
            # Neither a callback nor an environment was provided.
            raise ConfigurationError(msg)

        oidc_props = _OIDCProperties(
            callback=callback,
            human_callback=human_callback,
            environment=environ,
            allowed_hosts=allowed_hosts,
            token_resource=token_resource,
            username=user,
        )
        return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())

    elif mech == "PLAIN":
        # PLAIN falls back to the database, then to "$external".
        source_database = source or database or "$external"
        return MongoCredential(mech, source_database, user, passwd, None, None)
    else:
        # Every other mechanism falls back to the database, then to "admin".
        source_database = source or database or "admin"
        if passwd is None:
            raise ConfigurationError("A password is required.")
        return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
def _xor(fir: bytes, sec: bytes) -> bytes:
    """XOR two byte strings together, stopping at the end of the shorter one."""
    return bytes(left ^ right for left, right in zip(fir, sec))
|
||||
|
||||
|
||||
def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
    """Split a comma-separated scram response into a dict of key/value pairs.

    Each item is split on its first ``=`` only, so values may contain ``=``.
    """
    parts = response.split(b",")
    pairs = (typing.cast(typing.Tuple[bytes, bytes], part.split(b"=", 1)) for part in parts)
    return dict(pairs)
|
||||
|
||||
|
||||
def _authenticate_scram_start(
    credentials: MongoCredential, mechanism: str
) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
    """Build the first SCRAM message and its ``saslStart`` command.

    Returns ``(nonce, first_bare, cmd)``: a fresh random client nonce, the
    client-first-message-bare bytes, and the command document to send.
    """
    # "=" and "," are escaped in the username as "=3D" and "=2C".
    escaped_user = credentials.username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
    nonce = standard_b64encode(os.urandom(32))
    first_bare = b"".join((b"n=", escaped_user, b",r=", nonce))

    cmd = {
        "saslStart": 1,
        "mechanism": mechanism,
        "payload": Binary(b"n,," + first_bare),
        "autoAuthorize": 1,
        "options": {"skipEmptyExchange": True},
    }
    return nonce, first_bare, cmd
|
||||
|
||||
|
||||
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
    """Authenticate using SCRAM.

    :param credentials: Supplies username, password, source and the key cache.
    :param conn: Connection used to run the SASL commands; its ``auth_ctx`` is
        consulted for a speculative-authentication result.
    :param mechanism: "SCRAM-SHA-256" selects SHA-256 with a SASLprep'd
        password; any other value selects SHA-1 with the MD5 password digest.
    :raises OperationFailure: If the server returns an iteration count below
        4096, a nonce not prefixed by ours, an invalid signature, or never
        reports the conversation as done.
    """
    username = credentials.username
    if mechanism == "SCRAM-SHA-256":
        digest = "sha256"
        digestmod = hashlib.sha256
        data = saslprep(credentials.password).encode("utf-8")
    else:
        digest = "sha1"
        digestmod = hashlib.sha1
        data = _password_digest(username, credentials.password).encode("utf-8")
    source = credentials.source
    cache = credentials.cache

    # Make local
    _hmac = hmac.HMAC

    ctx = conn.auth_ctx
    if ctx and ctx.speculate_succeeded():
        # Reuse the nonce and first message saved by the speculative attempt.
        assert isinstance(ctx, _ScramContext)
        assert ctx.scram_data is not None
        nonce, first_bare = ctx.scram_data
        res = ctx.speculative_authenticate
    else:
        nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
        res = conn.command(source, cmd)

    assert res is not None
    server_first = res["payload"]
    parsed = _parse_scram_response(server_first)
    iterations = int(parsed[b"i"])
    if iterations < 4096:
        raise OperationFailure("Server returned an invalid iteration count.")
    salt = parsed[b"s"]
    rnonce = parsed[b"r"]
    # The server nonce must start with the client nonce we sent.
    if not rnonce.startswith(nonce):
        raise OperationFailure("Server returned an invalid nonce.")

    without_proof = b"c=biws,r=" + rnonce
    if cache.data:
        client_key, server_key, csalt, citerations = cache.data
    else:
        client_key, server_key, csalt, citerations = None, None, None, None

    # Salt and / or iterations could change for a number of different
    # reasons. Either changing invalidates the cache.
    if not client_key or salt != csalt or iterations != citerations:
        salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
        client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
        server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
        cache.data = (client_key, server_key, salt, iterations)
    stored_key = digestmod(client_key).digest()
    auth_msg = b",".join((first_bare, server_first, without_proof))
    client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
    client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
    client_final = b",".join((without_proof, client_proof))

    # Expected server signature, compared against the server's "v" field below.
    server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())

    cmd = {
        "saslContinue": 1,
        "conversationId": res["conversationId"],
        "payload": Binary(client_final),
    }
    res = conn.command(source, cmd)

    parsed = _parse_scram_response(res["payload"])
    # compare_digest gives a constant-time comparison of the signatures.
    if not hmac.compare_digest(parsed[b"v"], server_sig):
        raise OperationFailure("Server returned an invalid signature.")

    # A third empty challenge may be required if the server does not support
    # skipEmptyExchange: SERVER-44857.
    if not res["done"]:
        cmd = {
            "saslContinue": 1,
            "conversationId": res["conversationId"],
            "payload": Binary(b""),
        }
        res = conn.command(source, cmd)
        if not res["done"]:
            raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
|
||||
def _password_digest(username: str, password: str) -> str:
    """Return the hex MD5 digest of ``"<username>:mongo:<password>"``.

    Raises TypeError if either argument is not a str and ValueError if the
    password is empty.
    """
    if not isinstance(password, str):
        raise TypeError("password must be an instance of str")
    if len(password) == 0:
        raise ValueError("password can't be empty")
    if not isinstance(username, str):
        raise TypeError("username must be an instance of str")

    payload = f"{username}:mongo:{password}".encode("utf-8")
    return hashlib.md5(payload).hexdigest()  # noqa: S324
|
||||
|
||||
|
||||
def _auth_key(nonce: str, username: str, password: str) -> str:
    """Return the hex MD5 of nonce + username + password digest, used as the auth key."""
    digest = _password_digest(username, password)
    payload = f"{nonce}{username}{digest}".encode("utf-8")
    return hashlib.md5(payload).hexdigest()  # noqa: S324
|
||||
|
||||
|
||||
def _canonicalize_hostname(hostname: str) -> str:
    """Canonicalize hostname following MIT-krb5 behavior.

    Takes the first ``getaddrinfo`` result, then tries a reverse lookup on its
    address; if that lookup fails, the canonical name is used instead. The
    result is lowercased either way.
    """
    # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
    first = socket.getaddrinfo(
        hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
    )[0]
    canonname, sockaddr = first[3], first[4]

    try:
        resolved = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
    except socket.gaierror:
        return canonname.lower()

    return resolved[0].lower()
|
||||
|
||||
|
||||
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using GSSAPI.

    :param credentials: Supplies username, optional password and the GSSAPI
        mechanism properties (service name, realm, host canonicalization).
    :param conn: Connection whose address gives the service host and on which
        the SASL commands are run against "$external".
    :raises ConfigurationError: If no kerberos module could be imported.
    :raises OperationFailure: If any kerberos step fails or the conversation
        does not complete; ``kerberos.KrbError`` is converted to this type.
    """
    if not HAVE_KERBEROS:
        raise ConfigurationError(
            'The "kerberos" module must be installed to use GSSAPI authentication.'
        )

    try:
        username = credentials.username
        password = credentials.password
        props = credentials.mechanism_properties
        # Starting here and continuing through the while loop below - establish
        # the security context. See RFC 4752, Section 3.1, first paragraph.
        host = conn.address[0]
        if props.canonicalize_host_name:
            host = _canonicalize_hostname(host)
        service = props.service_name + "@" + host
        if props.service_realm is not None:
            service = service + "@" + props.service_realm

        if password is not None:
            if _USE_PRINCIPAL:
                # Note that, though we use unquote_plus for unquoting URI
                # options, we use quote here. Microsoft's UrlUnescape (used
                # by WinKerberos) doesn't support +.
                principal = ":".join((quote(username), quote(password)))
                result, ctx = kerberos.authGSSClientInit(
                    service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
                )
            else:
                # Split "user@domain" into separate arguments when a domain is present.
                if "@" in username:
                    user, domain = username.split("@", 1)
                else:
                    user, domain = username, None
                result, ctx = kerberos.authGSSClientInit(
                    service,
                    gssflags=kerberos.GSS_C_MUTUAL_FLAG,
                    user=user,
                    domain=domain,
                    password=password,
                )
        else:
            result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)

        if result != kerberos.AUTH_GSS_COMPLETE:
            raise OperationFailure("Kerberos context failed to initialize.")

        try:
            # pykerberos uses a weird mix of exceptions and return values
            # to indicate errors.
            # 0 == continue, 1 == complete, -1 == error
            # Only authGSSClientStep can return 0.
            if kerberos.authGSSClientStep(ctx, "") != 0:
                raise OperationFailure("Unknown kerberos failure in step function.")

            # Start a SASL conversation with mongod/s
            # Note: pykerberos deals with base64 encoded byte strings.
            # Since mongo accepts base64 strings as the payload we don't
            # have to use bson.binary.Binary.
            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslStart": 1,
                "mechanism": "GSSAPI",
                "payload": payload,
                "autoAuthorize": 1,
            }
            response = conn.command("$external", cmd)

            # Limit how many times we loop to catch protocol / library issues
            for _ in range(10):
                result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
                if result == -1:
                    raise OperationFailure("Unknown kerberos failure in step function.")

                payload = kerberos.authGSSClientResponse(ctx) or ""

                cmd = {
                    "saslContinue": 1,
                    "conversationId": response["conversationId"],
                    "payload": payload,
                }
                response = conn.command("$external", cmd)

                if result == kerberos.AUTH_GSS_COMPLETE:
                    break
            else:
                # The loop ran all 10 rounds without reaching AUTH_GSS_COMPLETE.
                raise OperationFailure("Kerberos authentication failed to complete.")

            # Once the security context is established actually authenticate.
            # See RFC 4752, Section 3.1, last two paragraphs.
            if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")

            if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
                raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")

            payload = kerberos.authGSSClientResponse(ctx)
            cmd = {
                "saslContinue": 1,
                "conversationId": response["conversationId"],
                "payload": payload,
            }
            conn.command("$external", cmd)

        finally:
            # Always release the kerberos context once it has been created.
            kerberos.authGSSClientClean(ctx)

    except kerberos.KrbError as exc:
        raise OperationFailure(str(exc)) from None
|
||||
|
||||
|
||||
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using SASL PLAIN (RFC 4616).

    Sends one ``saslStart`` command whose payload is NUL, username, NUL, password.
    """
    payload = f"\x00{credentials.username}\x00{credentials.password}".encode()
    conn.command(
        credentials.source,
        {
            "saslStart": 1,
            "mechanism": "PLAIN",
            "payload": Binary(payload),
            "autoAuthorize": 1,
        },
    )
|
||||
|
||||
|
||||
def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-X509.

    Does nothing when speculative authentication already succeeded on this
    connection; otherwise sends the X509 authenticate command to "$external".
    """
    ctx = conn.auth_ctx
    already_authenticated = ctx and ctx.speculate_succeeded()
    if not already_authenticated:
        # MONGODB-X509 is done after the speculative auth step, so the
        # explicit command is only needed when that step did not succeed.
        cmd = _X509Context(credentials, conn.address).speculate_command()
        conn.command("$external", cmd)
|
||||
|
||||
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
    """Authenticate using MONGODB-CR.

    Fetches a nonce from the server, derives the auth key from it, and sends
    the ``authenticate`` command.
    """
    source, username, password = credentials.source, credentials.username, credentials.password

    # Get a nonce
    nonce = conn.command(source, {"getnonce": 1})["nonce"]
    key = _auth_key(nonce, username, password)

    # Actually authenticate
    conn.command(source, {"authenticate": 1, "user": username, "nonce": nonce, "key": key})
|
||||
|
||||
|
||||
def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
    """Pick SCRAM-SHA-256 or SCRAM-SHA-1 for the DEFAULT mechanism and authenticate.

    Below wire version 7, SCRAM-SHA-1 is always used. Otherwise SCRAM-SHA-256
    is used when the server lists it for this user, falling back to SCRAM-SHA-1.
    """
    if conn.max_wire_version < 7:
        return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")

    if conn.negotiated_mechs:
        mechs = conn.negotiated_mechs
    else:
        # Nothing negotiated yet: ask the server which mechanisms the user has.
        source = credentials.source
        cmd = conn.hello_cmd()
        cmd["saslSupportedMechs"] = source + "." + credentials.username
        mechs = (conn.command(source, cmd, publish_events=False)).get("saslSupportedMechs", [])

    chosen = "SCRAM-SHA-256" if "SCRAM-SHA-256" in mechs else "SCRAM-SHA-1"
    return _authenticate_scram(credentials, conn, chosen)
|
||||
|
||||
|
||||
_AUTH_MAP: Mapping[str, Callable[..., None]] = {
|
||||
"GSSAPI": _authenticate_gssapi,
|
||||
"MONGODB-CR": _authenticate_mongo_cr,
|
||||
"MONGODB-X509": _authenticate_x509,
|
||||
"MONGODB-AWS": _authenticate_aws,
|
||||
"MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item]
|
||||
"PLAIN": _authenticate_plain,
|
||||
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
|
||||
"DEFAULT": _authenticate_default,
|
||||
}
|
||||
|
||||
|
||||
class _AuthContext:
    """Base class holding the state of a speculative authentication attempt."""

    def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
        self.credentials = credentials
        self.address = address
        # Set by parse_response() from the hello reply.
        self.speculative_authenticate: Optional[Mapping[str, Any]] = None

    @staticmethod
    def from_credentials(
        creds: MongoCredential, address: tuple[str, int]
    ) -> Optional[_AuthContext]:
        """Return a context for the credential's mechanism, or None if it has no entry."""
        spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
        if not spec_cls:
            return None
        return cast(_AuthContext, spec_cls(creds, address))

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Return the speculative auth command; subclasses must override."""
        raise NotImplementedError

    def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
        """Record the speculative authentication result from a hello reply."""
        self.speculative_authenticate = hello.speculative_authenticate

    def speculate_succeeded(self) -> bool:
        """Return True when a non-empty speculative result has been recorded."""
        return bool(self.speculative_authenticate)
|
||||
|
||||
|
||||
class _ScramContext(_AuthContext):
    """Speculative authentication context for the SCRAM mechanisms."""

    def __init__(
        self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
    ) -> None:
        super().__init__(credentials, address)
        self.mechanism = mechanism
        # (nonce, first_bare) captured by speculate_command().
        self.scram_data: Optional[tuple[bytes, bytes]] = None

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Build the speculative ``saslStart`` command and remember its nonce data."""
        nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
        # Save for later use.
        self.scram_data = (nonce, first_bare)
        # The 'db' field is included only on the speculative command.
        cmd["db"] = self.credentials.source
        return cmd
|
||||
|
||||
|
||||
class _X509Context(_AuthContext):
    """Speculative authentication context for MONGODB-X509."""

    def speculate_command(self) -> MutableMapping[str, Any]:
        """Return the X509 ``authenticate`` command, adding ``user`` only when a username is set."""
        username = self.credentials.username
        cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
        if username is not None:
            cmd["user"] = username
        return cmd
|
||||
|
||||
|
||||
class _OIDCContext(_AuthContext):
    """Speculative authentication context for MONGODB-OIDC."""

    def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
        """Return the authenticator's speculative command with ``db`` set, or None if it has none."""
        cmd = _get_authenticator(self.credentials, self.address).get_spec_auth_cmd()
        if cmd is not None:
            cmd["db"] = self.credentials.source
        return cmd
|
||||
|
||||
|
||||
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
|
||||
"MONGODB-X509": _X509Context,
|
||||
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
"MONGODB-OIDC": _OIDCContext,
|
||||
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
}
|
||||
|
||||
|
||||
def authenticate(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
    """Authenticate connection.

    Looks up the handler for the credential's mechanism. MONGODB-OIDC is
    called directly because it also takes the ``reauthenticate`` flag.
    """
    mechanism = credentials.mechanism
    # Looked up first so an unknown mechanism raises KeyError before any I/O.
    auth_func = _AUTH_MAP[mechanism]
    if mechanism != "MONGODB-OIDC":
        auth_func(credentials, conn)
        return
    _authenticate_oidc(credentials, conn, reauthenticate)
|
||||
@ -23,8 +23,10 @@ from pymongo.errors import ConfigurationError, OperationFailure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.typings import _ReadableBuffer
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.synchronous.auth import MongoCredential
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
@ -36,7 +38,6 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"MONGODB-AWS authentication requires pymongo-auth-aws: "
|
||||
"install with: python -m pip install 'pymongo[aws]'"
|
||||
) from e
|
||||
|
||||
# Delayed import.
|
||||
from pymongo_auth_aws.auth import ( # type:ignore[import]
|
||||
set_cached_credentials,
|
||||
378
pymongo/synchronous/auth_oidc.py
Normal file
378
pymongo/synchronous/auth_oidc.py
Normal file
@ -0,0 +1,378 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo._azure_helpers import _get_azure_response
|
||||
from pymongo._csot import remaining
|
||||
from pymongo._gcp_helpers import _get_gcp_response
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.auth import MongoCredential
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@dataclass
class OIDCIdPInfo:
    """Identity Provider information handed to OIDC callbacks.

    Built from the decoded server ``saslStart`` payload when that payload
    contains an ``issuer`` key.
    """

    # Issuer of the Identity Provider; the only required field.
    issuer: str
    # Client identifier, when one is supplied.
    clientId: Optional[str] = None
    # Scopes to request, when supplied.
    requestScopes: Optional[list[str]] = None
|
||||
|
||||
|
||||
@dataclass
class OIDCCallbackContext:
    """Arguments passed to :meth:`OIDCCallback.fetch`."""

    # Timeout, in seconds, the callback is given to produce a token.
    timeout_seconds: float
    # Username taken from the mechanism properties.
    username: str
    # Callback protocol version.
    version: int
    # Refresh token from a previous callback result, if any.
    refresh_token: Optional[str] = None
    # Identity Provider information, if any is known.
    idp_info: Optional[OIDCIdPInfo] = None
|
||||
|
||||
|
||||
@dataclass
class OIDCCallbackResult:
    """Value returned by :meth:`OIDCCallback.fetch`."""

    # The access token to authenticate with.
    access_token: str
    # Token lifetime in seconds, when the provider reports one.
    expires_in_seconds: Optional[float] = None
    # Refresh token to hand back to the callback on a later call, if any.
    refresh_token: Optional[str] = None
|
||||
|
||||
|
||||
class OIDCCallback(abc.ABC):
    """A base class for defining OIDC callbacks."""

    @abc.abstractmethod
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Obtain an access token for MONGODB-OIDC authentication.

        :param context: An :class:`OIDCCallbackContext` describing the request
            (timeout, username, callback version and, when available, a
            refresh token and Identity Provider information).
        :return: An :class:`OIDCCallbackResult` carrying the access token.
        """
|
||||
|
||||
|
||||
@dataclass
class _OIDCProperties:
    """Mechanism properties for MONGODB-OIDC authentication."""

    # Machine callback; when set, the machine authentication flow is used.
    callback: Optional[OIDCCallback] = field(default=None)
    # Human callback; takes precedence over ``callback`` when fetching tokens.
    human_callback: Optional[OIDCCallback] = field(default=None)
    # Built-in environment name; when set, allowed-host validation is skipped.
    environment: Optional[str] = field(default=None)
    # Host patterns (exact names or "*.suffix") the driver may authenticate to.
    allowed_hosts: list[str] = field(default_factory=list)
    # Token resource for the built-in provider callbacks.
    token_resource: Optional[str] = field(default=None)
    # Username forwarded to callbacks through OIDCCallbackContext.
    username: str = ""
|
||||
|
||||
# NOTE(review): not referenced in the code visible here; presumably a safety
# margin (in minutes) applied before token expiry -- confirm against callers.
TOKEN_BUFFER_MINUTES = 5
# Timeout, in seconds, passed to human callbacks via OIDCCallbackContext.
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
# Version number reported to callbacks via OIDCCallbackContext.version.
CALLBACK_VERSION = 1
# Fallback timeout, in seconds, for machine callbacks; used when
# pymongo._csot.remaining() returns a falsy value.
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
# Minimum spacing, in seconds, enforced between consecutive callback calls.
TIME_BETWEEN_CALLS_SECONDS = 0.1
|
||||
|
||||
|
||||
def _get_authenticator(
    credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
    """Return the authenticator cached on *credentials*, creating it if needed.

    :param credentials: The credentials whose ``cache.data`` slot holds the
        authenticator.
    :param address: The ``(host, port)`` being authenticated against.
    :return: The cached or newly created :class:`_OIDCAuthenticator`.
    :raises ConfigurationError: If no environment is configured and the host
        matches none of the allowed host patterns.
    """
    cached = credentials.cache.data
    if cached:
        return cached

    properties = credentials.mechanism_properties

    # The allow-list is only enforced when no built-in environment is set.
    if not properties.environment:
        host = address[0]
        allowed_hosts = properties.allowed_hosts
        # A pattern matches exactly, or as "*.suffix" against the host's tail.
        permitted = any(
            pattern == host or (pattern.startswith("*.") and host.endswith(pattern[1:]))
            for pattern in allowed_hosts
        )
        if not permitted:
            raise ConfigurationError(
                f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
            )

    # Nothing cached yet: build the authenticator and remember it.
    authenticator = _OIDCAuthenticator(username=credentials.username, properties=properties)
    credentials.cache.data = authenticator
    return authenticator
|
||||
|
||||
|
||||
class _OIDCTestCallback(OIDCCallback):
    """Callback for the "test" provider; reads the token from a file."""

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the stripped contents of the file named by ``OIDC_TOKEN_FILE``.

        :raises RuntimeError: If ``OIDC_TOKEN_FILE`` is unset or empty.
        """
        path = os.environ.get("OIDC_TOKEN_FILE")
        if path:
            with open(path) as token_fh:
                token = token_fh.read().strip()
            return OIDCCallbackResult(access_token=token)
        raise RuntimeError(
            'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
        )
|
||||
|
||||
|
||||
class _OIDCAWSCallback(OIDCCallback):
    """Callback for the "aws" provider; reads the web identity token file."""

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the stripped contents of ``AWS_WEB_IDENTITY_TOKEN_FILE``.

        :raises RuntimeError: If ``AWS_WEB_IDENTITY_TOKEN_FILE`` is unset or empty.
        """
        path = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
        if path:
            with open(path) as token_fh:
                token = token_fh.read().strip()
            return OIDCCallbackResult(access_token=token)
        raise RuntimeError(
            'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
        )
|
||||
|
||||
|
||||
class _OIDCAzureCallback(OIDCCallback):
    """Callback that obtains tokens through ``_get_azure_response``."""

    def __init__(self, token_resource: str) -> None:
        # The resource is URL-quoted once here and reused on every fetch.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token for the stored resource and the context's username."""
        response = _get_azure_response(
            self.token_resource, context.username, context.timeout_seconds
        )
        token = response["access_token"]
        lifetime = response["expires_in"]
        return OIDCCallbackResult(access_token=token, expires_in_seconds=lifetime)
|
||||
|
||||
|
||||
class _OIDCGCPCallback(OIDCCallback):
    """Callback that obtains tokens through ``_get_gcp_response``."""

    def __init__(self, token_resource: str) -> None:
        # The resource is URL-quoted once here and reused on every fetch.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token for the stored resource within the context's timeout."""
        response = _get_gcp_response(self.token_resource, context.timeout_seconds)
        token = response["access_token"]
        return OIDCCallbackResult(access_token=token)
|
||||
|
||||
|
||||
@dataclass
class _OIDCAuthenticator:
    """Holds cached OIDC tokens for one credential and runs the SASL exchange.

    One instance is cached per credential (see ``_get_authenticator``), so the
    tokens and the token generation counter are shared by every connection
    that authenticates with that credential.
    """

    username: str
    properties: _OIDCProperties
    refresh_token: Optional[str] = field(default=None)
    access_token: Optional[str] = field(default=None)
    idp_info: Optional[OIDCIdPInfo] = field(default=None)
    # Incremented every time a callback produces a new access token.
    token_gen_id: int = field(default=0)
    # Held while a callback is being invoked.
    lock: threading.Lock = field(default_factory=threading.Lock)
    # time.time() of the most recent callback invocation.
    last_call_time: float = field(default=0)

    def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle a reauthenticate from the server."""
        # Invalidate the token for the connection.
        self._invalidate(conn)
        # Call the appropriate auth logic for the callback type.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle an initial authenticate request."""
        # First handle speculative auth.
        # If it succeeded, we are done.
        ctx = conn.auth_ctx
        if ctx and ctx.speculate_succeeded():
            resp = ctx.speculative_authenticate
            if resp and resp["done"]:
                conn.oidc_token_gen_id = self.token_gen_id
                return resp

        # If spec auth failed, call the appropriate auth logic for the callback type.
        # We cannot assume that the token is invalid, because a proxy may have been
        # involved that stripped the speculative auth information.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
        """Get the appropriate speculative auth command."""
        if not self.access_token:
            return None
        return self._get_start_command({"jwt": self.access_token})

    def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
        """Authenticate using the machine callback flow."""
        # If there is a cached access token, try to authenticate with it. If
        # authentication fails with error code 18, invalidate the access token,
        # fetch a new access token, and try to authenticate again. If authentication
        # fails for any other reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    # _run_command has already invalidated the cached token.
                    return self._authenticate_machine(conn)
                raise
        return self._sasl_start_jwt(conn)

    def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Authenticate using the human callback flow."""
        # If we have a cached access token, try a JwtStepRequest. If
        # authentication fails with error code 18, invalidate the access token,
        # and try to authenticate again. If authentication fails for any other
        # reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    return self._authenticate_human(conn)
                raise

        # If we have a cached refresh token, try a JwtStepRequest with that.
        # If authentication fails with error code 18, invalidate the access and
        # refresh tokens, and try to authenticate again. If authentication fails for
        # any other reason, raise the error to the user.
        if self.refresh_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    self.refresh_token = None
                    return self._authenticate_human(conn)
                raise

        # Start a new Two-Step SASL conversation.
        # Run a PrincipalStepRequest to get the IdpInfo.
        cmd = self._get_start_command(None)
        start_resp = self._run_command(conn, cmd)
        # Attempt to authenticate with a JwtStepRequest.
        return self._sasl_continue_jwt(conn, start_resp)

    def _get_access_token(self) -> Optional[str]:
        """Return the cached access token, invoking the callback if needed.

        :return: The access token, or ``None`` when no token can be obtained
            (human flow without IdP info, or no callback configured).
        :raises ValueError: If the callback does not return an
            :class:`OIDCCallbackResult`.
        """
        properties = self.properties
        resp: OIDCCallbackResult

        is_human = properties.human_callback is not None
        if is_human and self.idp_info is None:
            return None

        # Start from None so a configuration without any callback yields
        # ``None`` below instead of an UnboundLocalError.
        cb: Optional[OIDCCallback] = None
        if properties.callback:
            cb = properties.callback
        if properties.human_callback:
            cb = properties.human_callback

        prev_token = self.access_token
        if prev_token:
            return prev_token

        if cb is None:
            return None

        with self.lock:
            # See if the token was changed while we were waiting for the
            # lock.
            new_token = self.access_token
            if new_token != prev_token:
                return new_token

            # Ensure that we are waiting a min time between callback invocations.
            delta = time.time() - self.last_call_time
            if delta < TIME_BETWEEN_CALLS_SECONDS:
                time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
            self.last_call_time = time.time()

            if is_human:
                timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
                assert self.idp_info is not None
            else:
                timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
            context = OIDCCallbackContext(
                timeout_seconds=timeout,
                version=CALLBACK_VERSION,
                refresh_token=self.refresh_token,
                idp_info=self.idp_info,
                username=self.properties.username,
            )
            resp = cb.fetch(context)
            if not isinstance(resp, OIDCCallbackResult):
                raise ValueError("Callback result must be of type OIDCCallbackResult")
            self.refresh_token = resp.refresh_token
            self.access_token = resp.access_token
            self.token_gen_id += 1

        return self.access_token

    def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Run *cmd* against ``$external``, invalidating the token on auth errors."""
        try:
            return conn.command("$external", cmd, no_reauth=True)  # type: ignore[call-arg]
        except OperationFailure as e:
            if self._is_auth_error(e):
                self._invalidate(conn)
            raise

    def _is_auth_error(self, err: Exception) -> bool:
        """Return True if *err* is an OperationFailure with the auth failure code."""
        if not isinstance(err, OperationFailure):
            return False
        return err.code == _AUTHENTICATION_FAILURE_CODE

    def _invalidate(self, conn: Connection) -> None:
        """Drop the cached access token unless *conn* holds an older generation."""
        # Ignore the invalidation if the connection's token gen id is less than
        # our current token gen id. A missing id on the connection counts as 0.
        token_gen_id = conn.oidc_token_gen_id or 0
        if token_gen_id < self.token_gen_id:
            return
        self.access_token = None

    def _sasl_continue_jwt(
        self, conn: Connection, start_resp: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Finish a two-step conversation by sending a freshly fetched token."""
        # Discard cached tokens so the callback is invoked again.
        self.access_token = None
        self.refresh_token = None
        start_payload: dict = bson.decode(start_resp["payload"])
        if "issuer" in start_payload:
            self.idp_info = OIDCIdPInfo(**start_payload)
        access_token = self._get_access_token()
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_continue_command({"jwt": access_token}, start_resp)
        return self._run_command(conn, cmd)

    def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
        """Send a ``saslStart`` carrying the current (or newly fetched) token."""
        access_token = self._get_access_token()
        # Record which token generation this connection authenticated with.
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_start_command({"jwt": access_token})
        return self._run_command(conn, cmd)

    def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
        """Build a ``saslStart`` command.

        When *payload* is ``None`` the payload is ``{"n": username}`` if a
        username is set, otherwise an empty document.
        """
        if payload is None:
            principal_name = self.username
            if principal_name:
                payload = {"n": principal_name}
            else:
                payload = {}
        bin_payload = Binary(bson.encode(payload))
        return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}

    def _get_continue_command(
        self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Build a ``saslContinue`` command for the conversation in *start_resp*."""
        bin_payload = Binary(bson.encode(payload))
        return {
            "saslContinue": 1,
            "payload": bin_payload,
            "conversationId": start_resp["conversationId"],
        }
|
||||
|
||||
|
||||
def _authenticate_oidc(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
    """Authenticate using MONGODB-OIDC.

    :param credentials: The credentials whose cached authenticator is used.
    :param conn: The connection to authenticate.
    :param reauthenticate: If true, run the reauthentication flow instead of
        the initial authentication flow.
    :return: Whatever the selected authenticator method returns.
    """
    authenticator = _get_authenticator(credentials, conn.address)
    handler = authenticator.reauthenticate if reauthenticate else authenticator.authenticate
    return handler(conn)
|
||||
@ -34,21 +34,23 @@ from typing import (
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import _csot, common
|
||||
from pymongo.client_session import ClientSession, _validate_session_write_concern
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo import _csot
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
|
||||
from pymongo.message import (
|
||||
from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
|
||||
from pymongo.synchronous.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo.synchronous.helpers import _get_wce_doc
|
||||
from pymongo.synchronous.message import (
|
||||
_DELETE,
|
||||
_INSERT,
|
||||
_UPDATE,
|
||||
@ -56,13 +58,15 @@ from pymongo.message import (
|
||||
_EncryptedBulkWriteContext,
|
||||
_randint,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
_DELETE_ONE: int = 1
|
||||
@ -449,7 +453,7 @@ class _Bulk:
|
||||
)
|
||||
|
||||
client = self.collection.database.client
|
||||
client._retryable_write(
|
||||
_ = client._retryable_write(
|
||||
self.is_retryable,
|
||||
retryable_bulk,
|
||||
session,
|
||||
497
pymongo/synchronous/change_stream.py
Normal file
497
pymongo/synchronous/change_stream.py
Normal file
@ -0,0 +1,497 @@
|
||||
# Copyright 2017 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Watch changes on a collection, a database, or the entire cluster."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
CursorNotFound,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.aggregation import (
|
||||
_AggregationCommand,
|
||||
_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand,
|
||||
)
|
||||
from pymongo.synchronous.collation import validate_collation_or_none
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.operations import _Op
|
||||
from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# On servers with a max wire version below 9, the following server error
# codes from the getMore command are treated as resumable. All other getMore
# server errors are non-resumable.
|
||||
_RESUMABLE_GETMORE_ERRORS = frozenset(
|
||||
[
|
||||
6, # HostUnreachable
|
||||
7, # HostNotFound
|
||||
89, # NetworkTimeout
|
||||
91, # ShutdownInProgress
|
||||
189, # PrimarySteppedDown
|
||||
262, # ExceededTimeLimit
|
||||
9001, # SocketException
|
||||
10107, # NotWritablePrimary
|
||||
11600, # InterruptedAtShutdown
|
||||
11602, # InterruptedDueToReplStateChange
|
||||
13435, # NotPrimaryNoSecondaryOk
|
||||
13436, # NotPrimaryOrSecondary
|
||||
63, # StaleShardVersion
|
||||
150, # StaleEpoch
|
||||
13388, # StaleConfig
|
||||
234, # RetryChangeStream
|
||||
133, # FailedToSatisfyReadPreference
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
    """Report whether *exc* is an error a change stream may resume from.

    Connection failures and CursorNotFound are always resumable. Other
    OperationFailures are judged by the ``ResumableChangeStreamError`` label
    on wire version 9+, and by error code on older wire versions.
    """
    if isinstance(exc, (ConnectionFailure, CursorNotFound)):
        return True
    if not isinstance(exc, OperationFailure):
        return False

    wire_version = exc._max_wire_version
    if wire_version is None:
        return False
    if wire_version >= 9:
        return exc.has_error_label("ResumableChangeStreamError")
    return exc.code in _RESUMABLE_GETMORE_ERRORS
|
||||
|
||||
|
||||
class ChangeStream(Generic[_DocumentType]):
|
||||
"""The internal abstract base class for change stream cursors.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
:meth:`pymongo.collection.Collection.watch`,
|
||||
:meth:`pymongo.database.Database.watch`, or
|
||||
:meth:`pymongo.mongo_client.MongoClient.watch` instead.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[
|
||||
MongoClient[_DocumentType],
|
||||
Database[_DocumentType],
|
||||
Collection[_DocumentType],
|
||||
],
|
||||
pipeline: Optional[_Pipeline],
|
||||
full_document: Optional[str],
|
||||
resume_after: Optional[Mapping[str, Any]],
|
||||
max_await_time_ms: Optional[int],
|
||||
batch_size: Optional[int],
|
||||
collation: Optional[_CollationIn],
|
||||
start_at_operation_time: Optional[Timestamp],
|
||||
session: Optional[ClientSession],
|
||||
start_after: Optional[Mapping[str, Any]],
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
show_expanded_events: Optional[bool] = None,
|
||||
) -> None:
|
||||
if pipeline is None:
|
||||
pipeline = []
|
||||
pipeline = common.validate_list("pipeline", pipeline)
|
||||
common.validate_string_or_none("full_document", full_document)
|
||||
validate_collation_or_none(collation)
|
||||
common.validate_non_negative_integer_or_none("batchSize", batch_size)
|
||||
|
||||
self._decode_custom = False
|
||||
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
|
||||
if target.codec_options.type_registry._decoder_map:
|
||||
self._decode_custom = True
|
||||
# Keep the type registry so that we support encoding custom types
|
||||
# in the pipeline.
|
||||
self._target = target.with_options( # type: ignore
|
||||
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
|
||||
)
|
||||
else:
|
||||
self._target = target
|
||||
|
||||
self._pipeline = copy.deepcopy(pipeline)
|
||||
self._full_document = full_document
|
||||
self._full_document_before_change = full_document_before_change
|
||||
self._uses_start_after = start_after is not None
|
||||
self._uses_resume_after = resume_after is not None
|
||||
self._resume_token = copy.deepcopy(start_after or resume_after)
|
||||
self._max_await_time_ms = max_await_time_ms
|
||||
self._batch_size = batch_size
|
||||
self._collation = collation
|
||||
self._start_at_operation_time = start_at_operation_time
|
||||
self._session = session
|
||||
self._comment = comment
|
||||
self._closed = False
|
||||
self._timeout = self._target._timeout
|
||||
self._show_expanded_events = show_expanded_events
|
||||
|
||||
def _initialize_cursor(self) -> None:
|
||||
# Initialize cursor.
|
||||
self._cursor = self._create_cursor()
|
||||
|
||||
@property
|
||||
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
|
||||
"""The aggregation command class to be used."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _client(self) -> MongoClient:
|
||||
"""The client against which the aggregation commands for
|
||||
this ChangeStream will be run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _change_stream_options(self) -> dict[str, Any]:
|
||||
"""Return the options dict for the $changeStream pipeline stage."""
|
||||
options: dict[str, Any] = {}
|
||||
if self._full_document is not None:
|
||||
options["fullDocument"] = self._full_document
|
||||
|
||||
if self._full_document_before_change is not None:
|
||||
options["fullDocumentBeforeChange"] = self._full_document_before_change
|
||||
|
||||
resume_token = self.resume_token
|
||||
if resume_token is not None:
|
||||
if self._uses_start_after:
|
||||
options["startAfter"] = resume_token
|
||||
else:
|
||||
options["resumeAfter"] = resume_token
|
||||
|
||||
elif self._start_at_operation_time is not None:
|
||||
options["startAtOperationTime"] = self._start_at_operation_time
|
||||
|
||||
if self._show_expanded_events:
|
||||
options["showExpandedEvents"] = self._show_expanded_events
|
||||
|
||||
return options
|
||||
|
||||
def _command_options(self) -> dict[str, Any]:
|
||||
"""Return the options dict for the aggregation command."""
|
||||
options = {}
|
||||
if self._max_await_time_ms is not None:
|
||||
options["maxAwaitTimeMS"] = self._max_await_time_ms
|
||||
if self._batch_size is not None:
|
||||
options["batchSize"] = self._batch_size
|
||||
return options
|
||||
|
||||
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
|
||||
"""Return the full aggregation pipeline for this ChangeStream."""
|
||||
options = self._change_stream_options()
|
||||
full_pipeline: list = [{"$changeStream": options}]
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
|
||||
"""Callback that caches the postBatchResumeToken or
|
||||
startAtOperationTime from a changeStream aggregate command response
|
||||
containing an empty batch of change documents.
|
||||
|
||||
This is implemented as a callback because we need access to the wire
|
||||
version in order to determine whether to cache this value.
|
||||
"""
|
||||
if not result["cursor"]["firstBatch"]:
|
||||
if "postBatchResumeToken" in result["cursor"]:
|
||||
self._resume_token = result["cursor"]["postBatchResumeToken"]
|
||||
elif (
|
||||
self._start_at_operation_time is None
|
||||
and self._uses_resume_after is False
|
||||
and self._uses_start_after is False
|
||||
and conn.max_wire_version >= 7
|
||||
):
|
||||
self._start_at_operation_time = result.get("operationTime")
|
||||
# PYTHON-2181: informative error on missing operationTime.
|
||||
if self._start_at_operation_time is None:
|
||||
raise OperationFailure(
|
||||
"Expected field 'operationTime' missing from command "
|
||||
f"response : {result!r}"
|
||||
)
|
||||
|
||||
def _run_aggregation_cmd(
|
||||
self, session: Optional[ClientSession], explicit_session: bool
|
||||
) -> CommandCursor:
|
||||
"""Run the full aggregation pipeline for this ChangeStream and return
|
||||
the corresponding CommandCursor.
|
||||
"""
|
||||
cmd = self._aggregation_command_class(
|
||||
self._target,
|
||||
CommandCursor,
|
||||
self._aggregation_pipeline(),
|
||||
self._command_options(),
|
||||
explicit_session,
|
||||
result_processor=self._process_result,
|
||||
comment=self._comment,
|
||||
)
|
||||
return self._client._retryable_read(
|
||||
cmd.get_cursor,
|
||||
self._target._read_preference_for(session),
|
||||
session,
|
||||
operation=_Op.AGGREGATE,
|
||||
)
|
||||
|
||||
def _create_cursor(self) -> CommandCursor:
|
||||
with self._client._tmp_session(self._session, close=False) as s:
|
||||
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
|
||||
|
||||
def _resume(self) -> None:
|
||||
"""Reestablish this change stream after a resumable error."""
|
||||
try:
|
||||
self._cursor.close()
|
||||
except PyMongoError:
|
||||
pass
|
||||
self._cursor = self._create_cursor()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close this ChangeStream."""
|
||||
self._closed = True
|
||||
self._cursor.close()
|
||||
|
||||
def __iter__(self) -> ChangeStream[_DocumentType]:
|
||||
return self
|
||||
|
||||
@property
|
||||
def resume_token(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The cached resume token that will be used to resume after the most
|
||||
recently returned change.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
return copy.deepcopy(self._resume_token)
|
||||
|
||||
@_csot.apply
|
||||
def next(self) -> _DocumentType:
|
||||
"""Advance the cursor.
|
||||
|
||||
This method blocks until the next change document is returned or an
|
||||
unrecoverable error is raised. This method is used when iterating over
|
||||
all changes in the cursor. For example::
|
||||
|
||||
try:
|
||||
resume_token = None
|
||||
pipeline = [{'$match': {'operationType': 'insert'}}]
|
||||
async with db.collection.watch(pipeline) as stream:
|
||||
async for insert_change in stream:
|
||||
print(insert_change)
|
||||
resume_token = stream.resume_token
|
||||
except pymongo.errors.PyMongoError:
|
||||
# The ChangeStream encountered an unrecoverable error or the
|
||||
# resume attempt failed to recreate the cursor.
|
||||
if resume_token is None:
|
||||
# There is no usable resume token because there was a
|
||||
# failure during ChangeStream initialization.
|
||||
logging.error('...')
|
||||
else:
|
||||
# Use the interrupted ChangeStream's resume token to create
|
||||
# a new ChangeStream. The new stream will continue from the
|
||||
# last seen insert change without missing any events.
|
||||
async with db.collection.watch(
|
||||
pipeline, resume_after=resume_token) as stream:
|
||||
async for insert_change in stream:
|
||||
print(insert_change)
|
||||
|
||||
Raises :exc:`StopIteration` if this ChangeStream is closed.
|
||||
"""
|
||||
while self.alive:
|
||||
doc = self.try_next()
|
||||
if doc is not None:
|
||||
return doc
|
||||
|
||||
raise StopIteration
|
||||
|
||||
__next__ = next
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
|
||||
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
"""
|
||||
return not self._closed
|
||||
|
||||
@_csot.apply
|
||||
def try_next(self) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor without blocking indefinitely.
|
||||
|
||||
This method returns the next change document without waiting
|
||||
indefinitely for the next change. For example::
|
||||
|
||||
async with db.collection.watch() as stream:
|
||||
while stream.alive:
|
||||
change = await stream.try_next()
|
||||
# Note that the ChangeStream's resume token may be updated
|
||||
# even when no changes are returned.
|
||||
print("Current resume token: %r" % (stream.resume_token,))
|
||||
if change is not None:
|
||||
print("Change document: %r" % (change,))
|
||||
continue
|
||||
# We end up here when there are no recent changes.
|
||||
# Sleep for a while before trying again to avoid flooding
|
||||
# the server with getMore requests when no changes are
|
||||
# available.
|
||||
time.sleep(10)
|
||||
|
||||
If no change document is cached locally then this method runs a single
|
||||
getMore command. If the getMore yields any documents, the next
|
||||
document is returned, otherwise, if the getMore returns no documents
|
||||
(because there have been no changes) then ``None`` is returned.
|
||||
|
||||
:return: The next change document or ``None`` when no document is available
|
||||
after running a single getMore or when the cursor is closed.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
"""
|
||||
if not self._closed and not self._cursor.alive:
|
||||
self._resume()
|
||||
|
||||
# Attempt to get the next change with at most one getMore and at most
|
||||
# one resume attempt.
|
||||
try:
|
||||
try:
|
||||
change = self._cursor._try_next(True)
|
||||
except PyMongoError as exc:
|
||||
if not _resumable(exc):
|
||||
raise
|
||||
self._resume()
|
||||
change = self._cursor._try_next(False)
|
||||
except PyMongoError as exc:
|
||||
# Close the stream after a fatal error.
|
||||
if not _resumable(exc) and not exc.timeout:
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
# Check if the cursor was invalidated.
|
||||
if not self._cursor.alive:
|
||||
self._closed = True
|
||||
|
||||
# If no changes are available.
|
||||
if change is None:
|
||||
# We have either iterated over all documents in the cursor,
|
||||
# OR the most-recently returned batch is empty. In either case,
|
||||
# update the cached resume token with the postBatchResumeToken if
|
||||
# one was returned. We also clear the startAtOperationTime.
|
||||
if self._cursor._post_batch_resume_token is not None:
|
||||
self._resume_token = self._cursor._post_batch_resume_token
|
||||
self._start_at_operation_time = None
|
||||
return change
|
||||
|
||||
# Else, changes are available.
|
||||
try:
|
||||
resume_token = change["_id"]
|
||||
except KeyError:
|
||||
self.close()
|
||||
raise InvalidOperation(
|
||||
"Cannot provide resume functionality when the resume token is missing."
|
||||
) from None
|
||||
|
||||
# If this is the last change document from the current batch, cache the
|
||||
# postBatchResumeToken.
|
||||
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
|
||||
resume_token = self._cursor._post_batch_resume_token
|
||||
|
||||
# Hereafter, don't use startAfter; instead use resumeAfter.
|
||||
self._uses_start_after = False
|
||||
self._uses_resume_after = True
|
||||
|
||||
# Cache the resume token and clear startAtOperationTime.
|
||||
self._resume_token = resume_token
|
||||
self._start_at_operation_time = None
|
||||
|
||||
if self._decode_custom:
|
||||
return _bson_to_dict(change.raw, self._orig_codec_options)
|
||||
return change
|
||||
|
||||
def __enter__(self) -> ChangeStream[_DocumentType]:
    """Enter a ``with`` block; the stream itself is the managed value."""
    return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Leave a ``with`` block by closing the stream.

    The exception details are ignored and ``None`` is returned, so an
    exception raised inside the block is not suppressed.
    """
    self.close()
|
||||
|
||||
|
||||
class CollectionChangeStream(ChangeStream[_DocumentType]):
    """Change stream scoped to a single collection.

    Application code should not construct this class itself; obtain an
    instance through :meth:`pymongo.collection.Collection.watch`.

    .. versionadded:: 3.7
    """

    _target: Collection[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
        """Return :class:`_CollectionAggregationCommand`."""
        return _CollectionAggregationCommand

    @property
    def _client(self) -> MongoClient[_DocumentType]:
        """Return the client reached through the target collection's database."""
        database = self._target.database
        return database.client
|
||||
|
||||
|
||||
class DatabaseChangeStream(ChangeStream[_DocumentType]):
    """Change stream covering every collection of one database.

    Application code should not construct this class itself; obtain an
    instance through :meth:`pymongo.database.Database.watch`.

    .. versionadded:: 3.7
    """

    _target: Database[_DocumentType]

    @property
    def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
        """Return :class:`_DatabaseAggregationCommand`."""
        return _DatabaseAggregationCommand

    @property
    def _client(self) -> MongoClient[_DocumentType]:
        """Return the client that owns the target database."""
        target = self._target
        return target.client
|
||||
|
||||
|
||||
class ClusterChangeStream(DatabaseChangeStream[_DocumentType]):
    """Change stream covering every collection in the cluster.

    Application code should not construct this class itself; obtain an
    instance through :meth:`pymongo.mongo_client.MongoClient.watch`.

    .. versionadded:: 3.7
    """

    def _change_stream_options(self) -> dict[str, Any]:
        """Return the inherited options with ``allChangesForCluster`` set."""
        cluster_options = super()._change_stream_options()
        cluster_options["allChangesForCluster"] = True
        return cluster_options
|
||||
334
pymongo/synchronous/client_options.py
Normal file
334
pymongo/synchronous/client_options.py
Normal file
@ -0,0 +1,334 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Tools to parse mongo client options."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.compression_support import CompressionSettings
|
||||
from pymongo.synchronous.monitoring import _EventListener, _EventListeners
|
||||
from pymongo.synchronous.pool import PoolOptions
|
||||
from pymongo.synchronous.read_preferences import (
|
||||
_ServerMode,
|
||||
make_read_preference,
|
||||
read_pref_mode_from_name,
|
||||
)
|
||||
from pymongo.synchronous.server_selectors import any_server_selector
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.synchronous.auth import MongoCredential
|
||||
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.synchronous.topology_description import _ServerSelector
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _parse_credentials(
    username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> Optional[MongoCredential]:
    """Build the authentication credentials described by the arguments.

    The mechanism comes from the ``authmechanism`` option; when that option
    is absent it falls back to ``"DEFAULT"`` if a username was given and to
    ``None`` otherwise. Returns ``None`` when there is neither a username nor
    a mechanism.
    """
    fallback_mechanism = "DEFAULT" if username else None
    mechanism = options.get("authmechanism", fallback_mechanism)
    source = options.get("authsource")
    if not (username or mechanism):
        return None

    # Deferred import: only needed once credentials actually have to be built.
    from pymongo.synchronous.auth import _build_credentials_tuple

    return _build_credentials_tuple(mechanism, source, username, password, options, database)
|
||||
|
||||
|
||||
def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode:
    """Resolve the read preference described by ``options``.

    A value stored under ``read_preference`` is returned untouched.
    Otherwise one is built from ``readpreference`` (default ``"primary"``),
    ``readpreferencetags`` and ``maxstalenessseconds`` (default ``-1``).
    """
    if "read_preference" in options:
        return options["read_preference"]

    return make_read_preference(
        read_pref_mode_from_name(options.get("readpreference", "primary")),
        options.get("readpreferencetags"),
        options.get("maxstalenessseconds", -1),
    )
|
||||
|
||||
|
||||
def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern:
    """Build a :class:`WriteConcern` from the write concern options.

    Reads ``w``, ``wtimeoutms``, ``journal`` and ``fsync``; any that are
    missing are passed as ``None``.
    """
    return WriteConcern(
        options.get("w"),
        options.get("wtimeoutms"),
        options.get("journal"),
        options.get("fsync"),
    )
|
||||
|
||||
|
||||
def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
    """Build a :class:`ReadConcern` from the ``readconcernlevel`` option."""
    return ReadConcern(options.get("readconcernlevel"))
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
    """Work out the TLS configuration described by ``options``.

    Returns a ``(context, allow_invalid_hostnames)`` pair. ``context`` is the
    value produced by :func:`get_ssl_context` when TLS ends up enabled and
    ``None`` otherwise. TLS is enabled when ``tls`` is truthy, or when ``tls``
    is absent and at least one other ``tls*`` option implies it.

    Raises :exc:`ConfigurationError` when ``tls`` is explicitly falsy although
    a TLS-implying option has been supplied.
    """
    use_tls = options.get("tls")
    if use_tls is not None:
        validate_boolean("tls", use_tls)

    allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False)

    # Any non-null value of these options implies tls=True.
    implied_when_set = (
        "tlscertificatekeyfile",
        "tlscertificatekeyfilepassword",
        "tlscafile",
        "tlscrlfile",
    )
    # A value of False for these options implies tls=True.
    implied_when_false = (
        "tlsallowinvalidcertificates",
        "tlsallowinvalidhostnames",
        "tlsdisableocspendpointcheck",
    )
    # Order matters: the names are echoed in the error message below.
    enabled_tls_opts = [name for name in implied_when_set if name in options and options[name]]
    enabled_tls_opts += [
        name for name in implied_when_false if name in options and not options[name]
    ]

    if enabled_tls_opts:
        if use_tls is None:
            # Implicitly enable TLS when one of the tls* options is set.
            use_tls = True
        elif not use_tls:
            # Error since tls is explicitly disabled but a tls option is set.
            raise ConfigurationError(
                "TLS has not been enabled but the "
                "following tls parameters have been set: "
                "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts)
            )

    if not use_tls:
        return None, allow_invalid_hostnames

    ctx = get_ssl_context(
        options.get("tlscertificatekeyfile"),
        options.get("tlscertificatekeyfilepassword"),
        options.get("tlscafile"),
        options.get("tlscrlfile"),
        options.get("tlsallowinvalidcertificates", False),
        allow_invalid_hostnames,
        options.get("tlsdisableocspendpointcheck", False),
    )
    return ctx, allow_invalid_hostnames
|
||||
|
||||
|
||||
def _parse_pool_options(
    username: str, password: str, database: Optional[str], options: Mapping[str, Any]
) -> PoolOptions:
    """Parse connection pool options.

    Gathers the credential, sizing, timeout, listener, compression and TLS
    settings found in ``options`` and bundles them into a
    :class:`PoolOptions`. Where a ``common.*`` default is given below, it is
    used when the option is absent; the remaining options default to ``None``.

    :param username: Passed through to :func:`_parse_credentials`.
    :param password: Passed through to :func:`_parse_credentials`.
    :param database: Passed through to :func:`_parse_credentials`.
    :param options: Mapping of option names to values.
    :return: The assembled :class:`PoolOptions`.
    :raises ValueError: If ``minpoolsize`` is greater than a non-``None``
        ``maxpoolsize``.
    """
    credentials = _parse_credentials(username, password, database, options)
    max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE)
    min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE)
    # NOTE(review): read from the "maxidletimems" key but named in seconds;
    # presumably the value is converted before it reaches this function -- confirm.
    max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC)
    # A max_pool_size of None skips the size comparison entirely.
    if max_pool_size is not None and min_pool_size > max_pool_size:
        raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
    connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
    socket_timeout = options.get("sockettimeoutms")
    wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
    # cast() only informs the type checker; the value is not validated here.
    event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
    appname = options.get("appname")
    driver = options.get("driver")
    server_api = options.get("server_api")
    compression_settings = CompressionSettings(
        options.get("compressors", []), options.get("zlibcompressionlevel", -1)
    )
    # May raise ConfigurationError for contradictory tls* options.
    ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
    load_balanced = options.get("loadbalanced")
    max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
    return PoolOptions(
        max_pool_size,
        min_pool_size,
        max_idle_time_seconds,
        connect_timeout,
        socket_timeout,
        wait_queue_timeout,
        ssl_context,
        tls_allow_invalid_hostnames,
        _EventListeners(event_listeners),
        appname,
        driver,
        compression_settings,
        max_connecting=max_connecting,
        server_api=server_api,
        load_balanced=load_balanced,
        credentials=credentials,
    )
|
||||
|
||||
|
||||
class ClientOptions:
    """Read only configuration options for a MongoClient.

    Should not be instantiated directly by application developers. Access
    a client's options via :attr:`pymongo.mongo_client.MongoClient.options`
    instead.
    """

    def __init__(
        self, username: str, password: str, database: Optional[str], options: Mapping[str, Any]
    ):
        """Parse ``options`` once and cache every derived setting.

        :param username: Forwarded to :func:`_parse_pool_options`.
        :param password: Forwarded to :func:`_parse_pool_options`.
        :param database: Forwarded to :func:`_parse_pool_options`.
        :param options: Mapping of option names to values; it is also kept
            as-is and exposed through :attr:`_options`.
        """
        self.__options = options
        self.__codec_options = _parse_codec_options(options)
        self.__direct_connection = options.get("directconnection")
        self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS)
        # self.__server_selection_timeout is in seconds. Must use full name for
        # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
        self.__server_selection_timeout = options.get(
            "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT
        )
        self.__pool_options = _parse_pool_options(username, password, database, options)
        self.__read_preference = _parse_read_preference(options)
        self.__replica_set_name = options.get("replicaset")
        self.__write_concern = _parse_write_concern(options)
        self.__read_concern = _parse_read_concern(options)
        self.__connect = options.get("connect")
        self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY)
        self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES)
        self.__retry_reads = options.get("retryreads", common.RETRY_READS)
        # Falls back to any_server_selector when no custom selector is given.
        self.__server_selector = options.get("server_selector", any_server_selector)
        self.__auto_encryption_opts = options.get("auto_encryption_opts")
        self.__load_balanced = options.get("loadbalanced")
        self.__timeout = options.get("timeoutms")
        self.__server_monitoring_mode = options.get(
            "servermonitoringmode", common.SERVER_MONITORING_MODE
        )

    @property
    def _options(self) -> Mapping[str, Any]:
        """The original options used to create this ClientOptions."""
        return self.__options

    @property
    def connect(self) -> Optional[bool]:
        """Whether to begin discovering a MongoDB topology automatically."""
        return self.__connect

    @property
    def codec_options(self) -> CodecOptions:
        """A :class:`~bson.codec_options.CodecOptions` instance."""
        return self.__codec_options

    @property
    def direct_connection(self) -> Optional[bool]:
        """Whether to connect to the deployment in 'Single' topology."""
        return self.__direct_connection

    @property
    def local_threshold_ms(self) -> int:
        """The local threshold for this instance."""
        return self.__local_threshold_ms

    @property
    def server_selection_timeout(self) -> int:
        """The server selection timeout for this instance in seconds."""
        return self.__server_selection_timeout

    @property
    def server_selector(self) -> _ServerSelector:
        """The ``server_selector`` option, or ``any_server_selector`` if unset."""
        return self.__server_selector

    @property
    def heartbeat_frequency(self) -> int:
        """The monitoring frequency in seconds."""
        return self.__heartbeat_frequency

    @property
    def pool_options(self) -> PoolOptions:
        """A :class:`~pymongo.pool.PoolOptions` instance."""
        return self.__pool_options

    @property
    def read_preference(self) -> _ServerMode:
        """A read preference instance."""
        return self.__read_preference

    @property
    def replica_set_name(self) -> Optional[str]:
        """Replica set name or None."""
        return self.__replica_set_name

    @property
    def write_concern(self) -> WriteConcern:
        """A :class:`~pymongo.write_concern.WriteConcern` instance."""
        return self.__write_concern

    @property
    def read_concern(self) -> ReadConcern:
        """A :class:`~pymongo.read_concern.ReadConcern` instance."""
        return self.__read_concern

    @property
    def timeout(self) -> Optional[float]:
        """The configured timeoutMS converted to seconds, or None.

        .. versionadded:: 4.2
        """
        return self.__timeout

    @property
    def retry_writes(self) -> bool:
        """If this instance should retry supported write operations."""
        return self.__retry_writes

    @property
    def retry_reads(self) -> bool:
        """If this instance should retry supported read operations."""
        return self.__retry_reads

    @property
    def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]:
        """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
        return self.__auto_encryption_opts

    @property
    def load_balanced(self) -> Optional[bool]:
        """True if the client was configured to connect to a load balancer."""
        return self.__load_balanced

    @property
    def event_listeners(self) -> list[_EventListeners]:
        """The event listeners registered for this client.

        See :mod:`~pymongo.monitoring` for details.

        .. versionadded:: 4.0
        """
        # NOTE(review): the assert is stripped under ``python -O``; it appears
        # to exist to narrow the Optional type for the checker -- confirm.
        assert self.__pool_options._event_listeners is not None
        return self.__pool_options._event_listeners.event_listeners()

    @property
    def server_monitoring_mode(self) -> str:
        """The configured serverMonitoringMode option.

        .. versionadded:: 4.5
        """
        return self.__server_monitoring_mode
|
||||
1157
pymongo/synchronous/client_session.py
Normal file
1157
pymongo/synchronous/client_session.py
Normal file
File diff suppressed because it is too large
Load Diff
226
pymongo/synchronous/collation.py
Normal file
226
pymongo/synchronous/collation.py
Normal file
@ -0,0 +1,226 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for working with `collations`_.
|
||||
|
||||
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class CollationStrength:
    """Allowed values for the `strength` option of a
    :class:`~pymongo.collation.Collation`.
    """

    PRIMARY = 1
    """Compare base (unadorned) characters only."""

    SECONDARY = 2
    """Additionally compare character accents."""

    TERTIARY = 3
    """Additionally compare character case."""

    QUATERNARY = 4
    """Additionally tell apart words with and without punctuation."""

    IDENTICAL = 5
    """Compare unicode code points (characters must be exactly identical)."""
|
||||
|
||||
|
||||
class CollationAlternate:
    """Allowed values for the `alternate` option of a
    :class:`~pymongo.collation.Collation`.
    """

    NON_IGNORABLE = "non-ignorable"
    """Treat spaces and punctuation as base characters."""

    SHIFTED = "shifted"
    """Do *not* treat spaces and punctuation as base characters.

    They are still distinguished when the
    :class:`~pymongo.collation.Collation` strength is at least
    :data:`~pymongo.collation.CollationStrength.QUATERNARY`.

    """
|
||||
|
||||
|
||||
class CollationMaxVariable:
    """Allowed values for the `max_variable` option of a
    :class:`~pymongo.collation.Collation`.
    """

    PUNCT = "punct"
    """Ignore both punctuation and spaces."""

    SPACE = "space"
    """Ignore spaces only."""
|
||||
|
||||
|
||||
class CollationCaseFirst:
    """Allowed values for the `case_first` option of a
    :class:`~pymongo.collation.Collation`.
    """

    UPPER = "upper"
    """Uppercase characters sort first."""

    LOWER = "lower"
    """Lowercase characters sort first."""

    OFF = "off"
    """Use the default for the locale or collation strength."""
|
||||
|
||||
|
||||
class Collation:
    """Collation

    :param locale: (string) The locale of the collation. This should be a string
        that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
        valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
        documentation for a list of supported locales.
    :param caseLevel: (optional) If ``True``, turn on case sensitivity if
        `strength` is 1 or 2 (case sensitivity is implied if `strength` is
        greater than 2). Defaults to ``False``.
    :param caseFirst: (optional) Specify that either uppercase or lowercase
        characters take precedence. Must be one of the following values:

          * :data:`~CollationCaseFirst.UPPER`
          * :data:`~CollationCaseFirst.LOWER`
          * :data:`~CollationCaseFirst.OFF` (the default)

    :param strength: Specify the comparison strength. This is also
        known as the ICU comparison level. This must be one of the following
        values:

          * :data:`~CollationStrength.PRIMARY`
          * :data:`~CollationStrength.SECONDARY`
          * :data:`~CollationStrength.TERTIARY` (the default)
          * :data:`~CollationStrength.QUATERNARY`
          * :data:`~CollationStrength.IDENTICAL`

        Each successive level builds upon the previous. For example, a
        `strength` of :data:`~CollationStrength.SECONDARY` differentiates
        characters based both on the unadorned base character and its accents.

    :param numericOrdering: If ``True``, order numbers numerically
        instead of in collation order (defaults to ``False``).
    :param alternate: Specify whether spaces and punctuation are
        considered base characters. This must be one of the following values:

          * :data:`~CollationAlternate.NON_IGNORABLE` (the default)
          * :data:`~CollationAlternate.SHIFTED`

    :param maxVariable: When `alternate` is
        :data:`~CollationAlternate.SHIFTED`, this option specifies what
        characters may be ignored. This must be one of the following values:

          * :data:`~CollationMaxVariable.PUNCT` (the default)
          * :data:`~CollationMaxVariable.SPACE`

    :param normalization: If ``True``, normalizes text into Unicode
        NFD. Defaults to ``False``.
    :param backwards: If ``True``, accents on characters are
        considered from the back of the word to the front, as it is done in some
        French dictionary ordering traditions. Defaults to ``False``.
    :param kwargs: Keyword arguments supplying any additional options
        to be sent with this Collation object.

    .. versionadded:: 3.4

    """

    __slots__ = ("__document",)

    def __init__(
        self,
        locale: str,
        caseLevel: Optional[bool] = None,
        caseFirst: Optional[str] = None,
        strength: Optional[int] = None,
        numericOrdering: Optional[bool] = None,
        alternate: Optional[str] = None,
        maxVariable: Optional[str] = None,
        normalization: Optional[bool] = None,
        backwards: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Validate the supplied options and store them as a document.

        Options left as ``None`` are omitted from the document. ``kwargs``
        are merged in last, without validation, and therefore override any
        named option that uses the same key.
        """
        locale = common.validate_string("locale", locale)
        self.__document: dict[str, Any] = {"locale": locale}

        # (document key, supplied value, validator). The order is significant:
        # it fixes both the validation order and the key order of the document.
        optional_fields = (
            ("caseLevel", caseLevel, validate_boolean),
            ("caseFirst", caseFirst, common.validate_string),
            ("strength", strength, common.validate_integer),
            ("numericOrdering", numericOrdering, validate_boolean),
            ("alternate", alternate, common.validate_string),
            ("maxVariable", maxVariable, common.validate_string),
            ("normalization", normalization, validate_boolean),
            ("backwards", backwards, validate_boolean),
        )
        for field, value, validate in optional_fields:
            if value is not None:
                self.__document[field] = validate(field, value)
        self.__document.update(kwargs)

    @property
    def document(self) -> dict[str, Any]:
        """The document representation of this collation.

        .. note::
          :class:`Collation` is immutable. Mutating the value of
          :attr:`document` does not mutate this :class:`Collation`.
        """
        # Hand out a shallow copy so callers cannot alter the stored document.
        return self.__document.copy()

    def __repr__(self) -> str:
        """Return ``Collation(key=value, ...)`` in document key order."""
        document = self.document
        fields = ", ".join(f"{key}={value!r}" for key, value in document.items())
        return f"Collation({fields})"

    def __eq__(self, other: Any) -> bool:
        """Compare by document; defer to ``other`` for non-Collation values."""
        if isinstance(other, Collation):
            return self.document == other.document
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        """Return the negation of ``self == other``."""
        return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(
    value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[dict[str, Any]]:
    """Normalise a collation argument to a plain document.

    ``None`` is passed through, a :class:`Collation` is replaced by its
    :attr:`~Collation.document`, and a ``dict`` is returned as given.

    Raises :exc:`TypeError` for any other value.
    """
    if value is None:
        return None
    if isinstance(value, Collation):
        return value.document
    if not isinstance(value, dict):
        raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
    return value
|
||||
3547
pymongo/synchronous/collection.py
Normal file
3547
pymongo/synchronous/collection.py
Normal file
File diff suppressed because it is too large
Load Diff
415
pymongo/synchronous/command_cursor.py
Normal file
415
pymongo/synchronous/command_cursor.py
Normal file
@ -0,0 +1,415 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CommandCursor class to iterate over command results."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Iterator,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.synchronous.cursor import _ConnectionManager
|
||||
from pymongo.synchronous.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
_OpReply,
|
||||
_RawBatchGetMore,
|
||||
)
|
||||
from pymongo.synchronous.response import PinnedResponse
|
||||
from pymongo.synchronous.typings import _Address, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class CommandCursor(Generic[_DocumentType]):
|
||||
"""A cursor / iterator over command cursors."""
|
||||
|
||||
_getmore_class = _GetMore
|
||||
|
||||
def __init__(
    self,
    collection: Collection[_DocumentType],
    cursor_info: Mapping[str, Any],
    address: Optional[_Address],
    batch_size: int = 0,
    max_await_time_ms: Optional[int] = None,
    session: Optional[ClientSession] = None,
    explicit_session: bool = False,
    comment: Any = None,
) -> None:
    """Create a new command cursor.

    :param collection: The collection this cursor belongs to; its
        ``full_name`` is used as the namespace when ``cursor_info`` has no
        ``"ns"`` entry.
    :param cursor_info: Cursor description. ``"id"`` and ``"firstBatch"``
        are required; ``"postBatchResumeToken"`` and ``"ns"`` are optional.
    :param address: Stored as given and exposed through :attr:`address`.
    :param batch_size: Validated and stored via :meth:`batch_size`.
    :param max_await_time_ms: Must be an ``int`` or ``None``.
    :param session: Session associated with the cursor, if any.
    :param explicit_session: Stored as given; when false, the session is
        ended here if the cursor is already exhausted.
    :param comment: Stored as given.

    Raises :exc:`TypeError` if `batch_size` or `max_await_time_ms` has the
    wrong type and :exc:`ValueError` if `batch_size` is negative.
    """
    self._sock_mgr: Any = None
    self._collection: Collection[_DocumentType] = collection
    self._id = cursor_info["id"]
    self._data = deque(cursor_info["firstBatch"])
    self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
        "postBatchResumeToken"
    )
    self._address = address
    self._batch_size = batch_size
    self._max_await_time_ms = max_await_time_ms
    self._session = session
    self._explicit_session = explicit_session
    # A cursor id of 0 is treated as already killed.
    self._killed = self._id == 0
    self._comment = comment
    # Already-killed cursor: release a non-explicit session right away.
    if _IS_SYNC and self._killed:
        self._end_session(True)  # type: ignore[unused-coroutine]

    if "ns" in cursor_info:  # noqa: SIM401
        self._ns = cursor_info["ns"]
    else:
        self._ns = collection.full_name

    # Re-assigns self._batch_size through the validating setter method.
    self.batch_size(batch_size)

    # NOTE(review): this check runs only after all attributes are assigned
    # and the session may already have been ended above.
    if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
        raise TypeError("max_await_time_ms must be an integer or None")
|
||||
|
||||
def __del__(self) -> None:
    """Close the cursor on finalisation, but only in the synchronous build."""
    if not _IS_SYNC:
        return
    self._die(False)  # type: ignore[unused-coroutine]
|
||||
|
||||
def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]:
    """Set how many documents are requested per batch and return ``self``.

    Every batch costs one round trip to the server, so this value can be
    tuned to balance performance against the amount of data transferred.

    .. note:: batch_size can not override MongoDB's internal limits on the
       amount of data it will return to the client in a single batch (i.e
       if you set batch size to 1,000,000,000, MongoDB will currently only
       return 4-16MB of results per batch).

    Raises :exc:`TypeError` if `batch_size` is not an integer.
    Raises :exc:`ValueError` if `batch_size` is less than ``0``.

    :param batch_size: The size of each batch of results requested.
    """
    if not isinstance(batch_size, int):
        raise TypeError("batch_size must be an integer")
    if batch_size < 0:
        raise ValueError("batch_size must be >= 0")

    # A requested size of 1 is stored as 2; any other value is kept as given.
    self._batch_size = 2 if batch_size == 1 else batch_size
    return self
|
||||
|
||||
def _has_next(self) -> bool:
    """Return `True` while documents from the most recent batch are still
    buffered locally.
    """
    return len(self._data) != 0
|
||||
|
||||
@property
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
    """Retrieve the postBatchResumeToken from the response to a
    changeStream aggregate or getMore.

    Returns the cached value; it is ``None`` when no token has been stored.
    """
    return self._postbatchresumetoken
|
||||
|
||||
def _maybe_pin_connection(self, conn: Connection) -> None:
    """Pin ``conn`` to this cursor when the client asks for it.

    Nothing happens if the client does not want this cursor pinned or if a
    connection manager is already attached.
    """
    client = self._collection.database.client
    if not client._should_pin_cursor(self._session):
        return
    if self._sock_mgr:
        return

    conn.pin_cursor()
    manager = _ConnectionManager(conn, False)
    if self._id == 0:
        # Ensure the connection gets returned when the entire result is
        # returned in the first batch.
        manager.close()
        return
    self._sock_mgr = manager
|
||||
|
||||
def _unpack_response(
    self,
    response: Union[_OpReply, _OpMsg],
    cursor_id: Optional[int],
    codec_options: CodecOptions[Mapping[str, Any]],
    user_fields: Optional[Mapping[str, Any]] = None,
    legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
    """Decode ``response`` by delegating to its own ``unpack_response``.

    All remaining arguments are forwarded positionally and unchanged.
    """
    documents = response.unpack_response(
        cursor_id,
        codec_options,
        user_fields,
        legacy_response,
    )
    return documents
|
||||
|
||||
@property
def alive(self) -> bool:
    """Does this cursor have the potential to return more data?

    ``True`` while documents are still buffered locally or the cursor has
    not been marked as killed.

    Even if :attr:`alive` is ``True``, :meth:`next` can raise
    :exc:`StopIteration`. Best to use a for loop::

        for doc in collection.aggregate(pipeline):
            print(doc)

    .. note:: :attr:`alive` can be True while iterating a cursor from
      a failed server. In this case :attr:`alive` will return False after
      :meth:`next` fails to retrieve the next batch of results from the
      server.
    """
    return bool(len(self._data) or not self._killed)
|
||||
|
||||
@property
def cursor_id(self) -> int:
    """Returns the id of the cursor.

    This is the stored ``_id`` value, returned unchanged.
    """
    return self._id
|
||||
|
||||
@property
def address(self) -> Optional[_Address]:
    """The (host, port) of the server used, or None.

    This is the stored ``_address`` value, returned unchanged.

    .. versionadded:: 3.0
    """
    return self._address
|
||||
|
||||
@property
def session(self) -> Optional[ClientSession]:
    """The cursor's :class:`~pymongo.client_session.ClientSession`, or None.

    Only a session flagged as explicit is exposed; otherwise ``None`` is
    returned even if a session is attached.

    .. versionadded:: 3.6
    """
    return self._session if self._explicit_session else None
|
||||
|
||||
def _die(self, synchronous: bool = False) -> None:
|
||||
"""Closes this cursor."""
|
||||
already_killed = self._killed
|
||||
self._killed = True
|
||||
if self._id and not already_killed:
|
||||
cursor_id = self._id
|
||||
assert self._address is not None
|
||||
address = _CursorAddress(self._address, self._ns)
|
||||
else:
|
||||
# Skip killCursors.
|
||||
cursor_id = 0
|
||||
address = None
|
||||
self._collection.database.client._cleanup_cursor(
|
||||
synchronous,
|
||||
cursor_id,
|
||||
address,
|
||||
self._sock_mgr,
|
||||
self._session,
|
||||
self._explicit_session,
|
||||
)
|
||||
if not self._explicit_session:
|
||||
self._session = None
|
||||
self._sock_mgr = None
|
||||
|
||||
def _end_session(self, synchronous: bool) -> None:
|
||||
if self._session and not self._explicit_session:
|
||||
self._session._end_session(lock=synchronous)
|
||||
self._session = None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Explicitly close / kill this cursor."""
|
||||
# ``True`` is the ``synchronous`` flag that ``_die`` forwards to
# ``client._cleanup_cursor``.
self._die(True)
|
||||
|
||||
def _send_message(self, operation: _GetMore) -> None:
    """Send a getmore message and handle the response.

    On any failure the cursor is cleaned up and the exception is re-raised.
    On success the cursor id, the post-batch resume token (command replies
    only) and the buffered documents are refreshed from the reply.
    """
    mongo_client = self._collection.database.client
    try:
        reply = mongo_client._run_operation(
            operation, self._unpack_response, address=self._address
        )
    except OperationFailure as failure:
        if failure.code in _CURSOR_CLOSED_ERRORS:
            # Don't send killCursors because the cursor is already closed.
            self._killed = True
        if failure.timeout:
            self._die(False)
        else:
            # Return the session and pinned connection, if necessary.
            self.close()
        raise
    except ConnectionFailure:
        # Don't send killCursors because the cursor is already closed.
        self._killed = True
        # Return the session and pinned connection, if necessary.
        self.close()
        raise
    except Exception:
        self.close()
        raise

    # Remember the connection of a pinned reply unless one is already held.
    if isinstance(reply, PinnedResponse) and not self._sock_mgr:
        self._sock_mgr = _ConnectionManager(reply.conn, reply.more_to_come)

    if not reply.from_command:
        batch = reply.docs
        assert isinstance(reply.data, _OpReply)
        self._id = reply.data.cursor_id
    else:
        cursor_doc = reply.docs[0]["cursor"]
        batch = cursor_doc["nextBatch"]
        self._postbatchresumetoken = cursor_doc.get("postBatchResumeToken")
        self._id = cursor_doc["id"]

    if self._id == 0:
        self.close()
    self._data = deque(batch)
|
||||
|
||||
def _refresh(self) -> int:
|
||||
"""Refreshes the cursor with more data from the server.
|
||||
|
||||
Returns the length of self._data after refresh. Will exit early if
|
||||
self._data is already non-empty. Raises OperationFailure when the
|
||||
cursor cannot be refreshed due to an error on the query.
|
||||
"""
|
||||
if len(self._data) or self._killed:
|
||||
return len(self._data)
|
||||
|
||||
if self._id: # Get More
|
||||
dbname, collname = self._ns.split(".", 1)
|
||||
read_pref = self._collection._read_preference_for(self.session)
|
||||
self._send_message(
|
||||
self._getmore_class(
|
||||
dbname,
|
||||
collname,
|
||||
self._batch_size,
|
||||
self._id,
|
||||
self._collection.codec_options,
|
||||
read_pref,
|
||||
self._session,
|
||||
self._collection.database.client,
|
||||
self._max_await_time_ms,
|
||||
self._sock_mgr,
|
||||
False,
|
||||
self._comment,
|
||||
)
|
||||
)
|
||||
else: # Cursor id is zero nothing else to return
|
||||
self._die(True)
|
||||
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self) -> Iterator[_DocumentType]:
|
||||
# The cursor is its own iterator; documents are produced by ``__next__``.
return self
|
||||
|
||||
def next(self) -> _DocumentType:
    """Advance the cursor.

    Keeps fetching until a document can be returned; raises
    :exc:`StopIteration` once the cursor is no longer alive.
    """
    # Block until a document is returnable.
    while True:
        if not self.alive:
            raise StopIteration
        candidate = self._try_next(True)
        if candidate is not None:
            return candidate
|
||||
|
||||
def __next__(self) -> _DocumentType:
|
||||
# Iterator-protocol entry point; all of the work happens in ``next``.
return self.next()
|
||||
|
||||
def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor blocking for at most one getMore command."""
|
||||
if not len(self._data) and not self._killed and get_more_allowed:
|
||||
self._refresh()
|
||||
if len(self._data):
|
||||
return self._data.popleft()
|
||||
else:
|
||||
return None
|
||||
|
||||
def try_next(self) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor without blocking indefinitely.
|
||||
|
||||
This method returns the next document without waiting
|
||||
indefinitely for data.
|
||||
|
||||
If no document is cached locally then this method runs a single
|
||||
getMore command. If the getMore yields any documents, the next
|
||||
document is returned, otherwise, if the getMore returns no documents
|
||||
(because there is no additional data) then ``None`` is returned.
|
||||
|
||||
:return: The next document or ``None`` when no document is available
|
||||
after running a single getMore or when the cursor is closed.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
"""
|
||||
# ``_try_next`` calls ``_refresh`` at most once, and only when nothing is
# buffered and the cursor has not been killed.
return self._try_next(get_more_allowed=True)
|
||||
|
||||
def __enter__(self) -> CommandCursor[_DocumentType]:
|
||||
# Context-manager entry: the cursor itself is the managed object.
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
# Always close the cursor on leaving the ``with`` block; the exception
# details are ignored and, since None is returned, never suppressed.
self.close()
|
||||
|
||||
def to_list(self) -> list[_DocumentType]:
|
||||
# Drains the cursor by iterating ``self`` (i.e. through ``__next__``) and
# returns every remaining document as a list.
# NOTE(review): the comprehension (rather than ``list(self)``) and its noqa
# look deliberate -- presumably to mirror an async counterpart; confirm
# before simplifying.
return [x for x in self] # noqa: C416,RUF100
|
||||
|
||||
|
||||
class RawBatchCommandCursor(CommandCursor[_DocumentType]):
    """A command cursor that iterates over raw batches of BSON data."""

    _getmore_class = _RawBatchGetMore

    def __init__(
        self,
        collection: Collection[_DocumentType],
        cursor_info: Mapping[str, Any],
        address: Optional[_Address],
        batch_size: int = 0,
        max_await_time_ms: Optional[int] = None,
        session: Optional[ClientSession] = None,
        explicit_session: bool = False,
        comment: Any = None,
    ) -> None:
        """Create a new cursor / iterator over raw batches of BSON data.

        Should not be called directly by application developers -
        see :meth:`~pymongo.collection.Collection.aggregate_raw_batches`
        instead.

        .. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
        """
        # A raw-batch cursor must start without an already-decoded first batch.
        assert not cursor_info.get("firstBatch")
        super().__init__(
            collection,
            cursor_info,
            address,
            batch_size,
            max_await_time_ms,
            session,
            explicit_session,
            comment,
        )

    def _unpack_response(  # type: ignore[override]
        self,
        response: Union[_OpReply, _OpMsg],
        cursor_id: Optional[int],
        codec_options: CodecOptions,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[Mapping[str, Any]]:
        """Return the reply's raw payload instead of decoded documents.

        ``codec_options`` is accepted for signature compatibility but is not
        used here.
        """
        raw_docs = response.raw_response(cursor_id, user_fields=user_fields)
        if legacy_response:
            return raw_docs  # type: ignore[return-value]
        # OP_MSG returns firstBatch/nextBatch documents as a BSON array
        # Re-assemble the array of documents into a document stream
        _convert_raw_document_lists_to_streams(raw_docs[0])
        return raw_docs  # type: ignore[return-value]

    def __getitem__(self, index: int) -> NoReturn:
        """Indexing is not supported; always raises ``InvalidOperation``."""
        raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
|
||||
@ -40,20 +40,22 @@ from bson import SON
|
||||
from bson.binary import UuidRepresentation
|
||||
from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo.compression_support import (
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.synchronous.compression_support import (
|
||||
validate_compressors,
|
||||
validate_zlib_compression_level,
|
||||
)
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _validate_event_listeners
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.synchronous.monitoring import _validate_event_listeners
|
||||
from pymongo.synchronous.read_preferences import _MONGOS_MODES, _ServerMode
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
|
||||
|
||||
@ -378,7 +380,7 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
|
||||
|
||||
def validate_auth_mechanism(option: str, value: Any) -> str:
|
||||
"""Validate the authMechanism URI option."""
|
||||
from pymongo.auth import MECHANISMS
|
||||
from pymongo.synchronous.auth import MECHANISMS
|
||||
|
||||
if value not in MECHANISMS:
|
||||
raise ValueError(f"{option} must be in {tuple(MECHANISMS)}")
|
||||
@ -444,7 +446,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
elif key in ["ALLOWED_HOSTS"] and isinstance(value, list):
|
||||
props[key] = value
|
||||
elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]:
|
||||
from pymongo.auth_oidc import OIDCCallback
|
||||
from pymongo.synchronous.auth_oidc import OIDCCallback
|
||||
|
||||
if not isinstance(value, OIDCCallback):
|
||||
raise ValueError("callback must be an OIDCCallback object")
|
||||
@ -640,7 +642,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
|
||||
"""Validate the driver keyword arg."""
|
||||
if value is None:
|
||||
return value
|
||||
from pymongo.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.synchronous.encryption_options import AutoEncryptionOpts
|
||||
|
||||
if not isinstance(value, AutoEncryptionOpts):
|
||||
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
|
||||
@ -902,7 +904,7 @@ class BaseObject:
|
||||
) -> None:
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
|
||||
self.__codec_options = codec_options
|
||||
self._codec_options = codec_options
|
||||
|
||||
if not isinstance(read_preference, _ServerMode):
|
||||
raise TypeError(
|
||||
@ -910,24 +912,24 @@ class BaseObject:
|
||||
"pymongo.read_preferences for valid "
|
||||
"options."
|
||||
)
|
||||
self.__read_preference = read_preference
|
||||
self._read_preference = read_preference
|
||||
|
||||
if not isinstance(write_concern, WriteConcern):
|
||||
raise TypeError(
|
||||
"write_concern must be an instance of pymongo.write_concern.WriteConcern"
|
||||
)
|
||||
self.__write_concern = write_concern
|
||||
self._write_concern = write_concern
|
||||
|
||||
if not isinstance(read_concern, ReadConcern):
|
||||
raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern")
|
||||
self.__read_concern = read_concern
|
||||
self._read_concern = read_concern
|
||||
|
||||
@property
|
||||
def codec_options(self) -> CodecOptions:
|
||||
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
|
||||
of this instance.
|
||||
"""
|
||||
return self.__codec_options
|
||||
return self._codec_options
|
||||
|
||||
@property
|
||||
def write_concern(self) -> WriteConcern:
|
||||
@ -937,7 +939,7 @@ class BaseObject:
|
||||
.. versionchanged:: 3.0
|
||||
The :attr:`write_concern` attribute is now read only.
|
||||
"""
|
||||
return self.__write_concern
|
||||
return self._write_concern
|
||||
|
||||
def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern:
|
||||
"""Read only access to the write concern of this instance or session."""
|
||||
@ -953,14 +955,14 @@ class BaseObject:
|
||||
.. versionchanged:: 3.0
|
||||
The :attr:`read_preference` attribute is now read only.
|
||||
"""
|
||||
return self.__read_preference
|
||||
return self._read_preference
|
||||
|
||||
def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode:
|
||||
"""Read only access to the read preference of this instance or session."""
|
||||
# Override this operation's read preference with the transaction's.
|
||||
if session:
|
||||
return session._txn_read_preference() or self.__read_preference
|
||||
return self.__read_preference
|
||||
return session._txn_read_preference() or self._read_preference
|
||||
return self._read_preference
|
||||
|
||||
@property
|
||||
def read_concern(self) -> ReadConcern:
|
||||
@ -969,7 +971,7 @@ class BaseObject:
|
||||
|
||||
.. versionadded:: 3.2
|
||||
"""
|
||||
return self.__read_concern
|
||||
return self._read_concern
|
||||
|
||||
|
||||
class _CaseInsensitiveDictionary(MutableMapping[str, Any]):
|
||||
@ -16,8 +16,11 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.helpers import _SENSITIVE_COMMANDS
|
||||
from pymongo.helpers_constants import _SENSITIVE_COMMANDS
|
||||
from pymongo.synchronous.hello_compat import HelloCompat
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
|
||||
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
|
||||
@ -146,6 +149,7 @@ class ZstdContext:
|
||||
def compress(data: bytes) -> bytes:
|
||||
# ZstdCompressor is not thread safe.
|
||||
# TODO: Use a pool?
|
||||
|
||||
import zstandard
|
||||
|
||||
return zstandard.ZstdCompressor().compress(data)
|
||||
1289
pymongo/synchronous/cursor.py
Normal file
1289
pymongo/synchronous/cursor.py
Normal file
File diff suppressed because it is too large
Load Diff
1419
pymongo/synchronous/database.py
Normal file
1419
pymongo/synchronous/database.py
Normal file
File diff suppressed because it is too large
Load Diff
1120
pymongo/synchronous/encryption.py
Normal file
1120
pymongo/synchronous/encryption.py
Normal file
File diff suppressed because it is too large
Load Diff
270
pymongo/synchronous/encryption_options.py
Normal file
270
pymongo/synchronous/encryption_options.py
Normal file
@ -0,0 +1,270 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import] # noqa: F401
|
||||
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
except ImportError:
|
||||
_HAVE_PYMONGOCRYPT = False
|
||||
from bson import int64
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.synchronous.common import validate_is_mapping
|
||||
from pymongo.synchronous.uri_parser import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.typings import _DocumentTypeArg
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
    """Options to configure automatic client-side field level encryption."""

    def __init__(
        self,
        kms_providers: Mapping[str, Any],
        key_vault_namespace: str,
        key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None,
        schema_map: Optional[Mapping[str, Any]] = None,
        bypass_auto_encryption: bool = False,
        mongocryptd_uri: str = "mongodb://localhost:27020",
        mongocryptd_bypass_spawn: bool = False,
        mongocryptd_spawn_path: str = "mongocryptd",
        mongocryptd_spawn_args: Optional[list[str]] = None,
        kms_tls_options: Optional[Mapping[str, Any]] = None,
        crypt_shared_lib_path: Optional[str] = None,
        crypt_shared_lib_required: bool = False,
        bypass_query_analysis: bool = False,
        encrypted_fields_map: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Options to configure automatic client-side field level encryption.

        Automatic client-side field level encryption requires MongoDB >=4.2
        enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not
        supported for operations on a database or view and will result in
        error.

        Although automatic encryption requires MongoDB >=4.2 enterprise or a
        MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all
        users. To configure automatic *decryption* without automatic
        *encryption* set ``bypass_auto_encryption=True``. Explicit
        encryption and explicit decryption is also supported for all users
        with the :class:`~pymongo.encryption.ClientEncryption` class.

        See :ref:`automatic-client-side-encryption` for an example.

        :param kms_providers: Map of KMS provider options. The `kms_providers`
            map values differ by provider:

              - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
                These are the AWS access key ID and AWS secret access key used
                to generate KMS messages. An optional "sessionToken" may be
                included to support temporary AWS credentials.
              - `azure`: Map with "tenantId", "clientId", and "clientSecret" as
                strings. Additionally, "identityPlatformEndpoint" may also be
                specified as a string (defaults to 'login.microsoftonline.com').
                These are the Azure Active Directory credentials used to
                generate Azure Key Vault messages.
              - `gcp`: Map with "email" as a string and "privateKey"
                as `bytes` or a base64 encoded string.
                Additionally, "endpoint" may also be specified as a string
                (defaults to 'oauth2.googleapis.com'). These are the
                credentials used to generate Google Cloud KMS messages.
              - `kmip`: Map with "endpoint" as a host with required port.
                For example: ``{"endpoint": "example.com:443"}``.
              - `local`: Map with "key" as `bytes` (96 bytes in length) or
                a base64 encoded string which decodes
                to 96 bytes. "key" is the master key used to encrypt/decrypt
                data keys. This key should be generated and stored as securely
                as possible.

            KMS providers may be specified with an optional name suffix
            separated by a colon, for example "kmip:name" or "aws:name".
            Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
            Named KMS providers enable more than one of each KMS provider type to be configured.
            For example, to configure multiple local KMS providers::

                kms_providers = {
                    "local": {"key": local_kek1},  # Unnamed KMS provider.
                    "local:myname": {"key": local_kek2},  # Named KMS provider with name "myname".
                }

        :param key_vault_namespace: The namespace for the key vault collection.
            The key vault collection contains all data keys used for encryption
            and decryption. Data keys are stored as documents in this MongoDB
            collection. Data keys are protected with encryption by a KMS
            provider.
        :param key_vault_client: By default, the key vault collection
            is assumed to reside in the same MongoDB cluster as the encrypted
            MongoClient. Use this option to route data key queries to a
            separate MongoDB cluster.
        :param schema_map: Map of collection namespace ("db.coll") to
            JSON Schema.  By default, a collection's JSONSchema is periodically
            polled with the listCollections command. But a JSONSchema may be
            specified locally with the schemaMap option.

            **Supplying a `schema_map` provides more security than relying on
            JSON Schemas obtained from the server. It protects against a
            malicious server advertising a false JSON Schema, which could trick
            the client into sending unencrypted data that should be
            encrypted.**

            Schemas supplied in the schemaMap only apply to configuring
            automatic encryption for client side encryption. Other validation
            rules in the JSON schema will not be enforced by the driver and
            will result in an error.
        :param bypass_auto_encryption: If ``True``, automatic
            encryption will be disabled but automatic decryption will still be
            enabled. Defaults to ``False``.
        :param mongocryptd_uri: The MongoDB URI used to connect
            to the *local* mongocryptd process. Defaults to
            ``'mongodb://localhost:27020'``.
        :param mongocryptd_bypass_spawn: If ``True``, the encrypted
            MongoClient will not attempt to spawn the mongocryptd process.
            Defaults to ``False``.
        :param mongocryptd_spawn_path: Used for spawning the
            mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
            mongocryptd from the system path.
        :param mongocryptd_spawn_args: A list of string arguments to
            use when spawning the mongocryptd process. Defaults to
            ``['--idleShutdownTimeoutSecs=60']``. If the list does not include
            the ``idleShutdownTimeoutSecs`` option then
            ``'--idleShutdownTimeoutSecs=60'`` will be added. The list passed
            in is copied and is never modified.
        :param kms_tls_options: A map of KMS provider names to TLS
            options to use when creating secure connections to KMS providers.
            Accepts the same TLS options as
            :class:`pymongo.mongo_client.MongoClient`. For example, to
            override the system default CA file::

              kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}

            Or to supply a client certificate::

              kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
        :param crypt_shared_lib_path: Override the path to load the crypt_shared library.
        :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
            unable to load the crypt_shared library.
        :param bypass_query_analysis: If ``True``, disable automatic analysis
            of outgoing commands. Set `bypass_query_analysis` to use explicit
            encryption on indexed fields without the MongoDB Enterprise Advanced
            licensed crypt_shared library.
        :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
            that describe the encrypted fields for Queryable Encryption. For example::

                {
                  "db.encryptedCollection": {
                      "escCollection": "enxcol_.encryptedCollection.esc",
                      "ecocCollection": "enxcol_.encryptedCollection.ecoc",
                      "fields": [
                          {
                              "path": "firstName",
                              "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
                              "bsonType": "string",
                              "queries": {"queryType": "equality"}
                          },
                          {
                              "path": "ssn",
                              "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
                              "bsonType": "string"
                          }
                      ]
                  }
                }

        :raises ConfigurationError: If the pymongocrypt library is not installed.
        :raises TypeError: If `mongocryptd_spawn_args` is given and is not a list.

        .. versionchanged:: 4.2
           Added the `encrypted_fields_map`, `crypt_shared_lib_path`, `crypt_shared_lib_required`,
           and `bypass_query_analysis` parameters.

        .. versionchanged:: 4.0
           Added the `kms_tls_options` parameter and the "kmip" KMS provider.

        .. versionadded:: 3.9
        """
        if not _HAVE_PYMONGOCRYPT:
            raise ConfigurationError(
                "client side encryption requires the pymongocrypt library: "
                "install a compatible version with: "
                "python -m pip install 'pymongo[encryption]'"
            )
        if encrypted_fields_map:
            validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
        self._encrypted_fields_map = encrypted_fields_map
        self._bypass_query_analysis = bypass_query_analysis
        self._crypt_shared_lib_path = crypt_shared_lib_path
        self._crypt_shared_lib_required = crypt_shared_lib_required
        self._kms_providers = kms_providers
        self._key_vault_namespace = key_vault_namespace
        self._key_vault_client = key_vault_client
        self._schema_map = schema_map
        self._bypass_auto_encryption = bypass_auto_encryption
        self._mongocryptd_uri = mongocryptd_uri
        self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
        self._mongocryptd_spawn_path = mongocryptd_spawn_path
        if mongocryptd_spawn_args is None:
            mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
        if not isinstance(mongocryptd_spawn_args, list):
            raise TypeError("mongocryptd_spawn_args must be a list")
        # Work on a private copy so that adding the default idle timeout below
        # never mutates the list object owned by the caller.
        spawn_args = list(mongocryptd_spawn_args)
        if not any("idleShutdownTimeoutSecs" in s for s in spawn_args):
            spawn_args.append("--idleShutdownTimeoutSecs=60")
        self._mongocryptd_spawn_args = spawn_args
        # Maps KMS provider name to a SSLContext.
        self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
|
||||
|
||||
|
||||
class RangeOpts:
    """Options to configure encrypted queries using the rangePreview algorithm."""

    def __init__(
        self,
        sparsity: int,
        min: Optional[Any] = None,
        max: Optional[Any] = None,
        precision: Optional[int] = None,
    ) -> None:
        """Options to configure encrypted queries using the rangePreview algorithm.

        .. note:: This feature is experimental only, and not intended for public use.

        :param sparsity: An integer.
        :param min: A BSON scalar value corresponding to the type being queried.
        :param max: A BSON scalar value corresponding to the type being queried.
        :param precision: An integer, may only be set for double or decimal128 types.

        .. versionadded:: 4.4
        """
        self.min = min
        self.max = max
        self.sparsity = sparsity
        self.precision = precision

    @property
    def document(self) -> dict[str, Any]:
        """The options as a dict, leaving out every entry whose value is ``None``.

        ``sparsity`` is wrapped in ``int64.Int64``; keys appear in the order
        sparsity, precision, min, max.
        """
        candidates = (
            ("sparsity", int64.Int64(self.sparsity)),
            ("precision", self.precision),
            ("min", self.min),
            ("max", self.max),
        )
        return {key: value for key, value in candidates if value is not None}
|
||||
225
pymongo/synchronous/event_loggers.py
Normal file
225
pymongo/synchronous/event_loggers.py
Normal file
@ -0,0 +1,225 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Example event logger classes.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
These loggers can be registered using :func:`register` or
|
||||
:class:`~pymongo.mongo_client.MongoClient`.
|
||||
|
||||
``monitoring.register(CommandLogger())``
|
||||
|
||||
or
|
||||
|
||||
``MongoClient(event_listeners=[CommandLogger()])``
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from pymongo.synchronous import monitoring
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class CommandLogger(monitoring.CommandListener):
    """A simple listener that logs command events.

    Handles :class:`~pymongo.monitoring.CommandStartedEvent`,
    :class:`~pymongo.monitoring.CommandSucceededEvent` and
    :class:`~pymongo.monitoring.CommandFailedEvent` events, writing each one
    at the `INFO` severity level using :mod:`logging`.

    .. versionadded:: 3.11
    """

    def started(self, event: monitoring.CommandStartedEvent) -> None:
        """Log the command name, request id and server of a started command."""
        message = (
            f"Command {event.command_name} with request id "
            f"{event.request_id} started on server "
            f"{event.connection_id}"
        )
        logging.info(message)

    def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
        """Log a successful command together with its duration in microseconds."""
        message = (
            f"Command {event.command_name} with request id "
            f"{event.request_id} on server {event.connection_id} "
            f"succeeded in {event.duration_micros} "
            "microseconds"
        )
        logging.info(message)

    def failed(self, event: monitoring.CommandFailedEvent) -> None:
        """Log a failed command together with its duration in microseconds."""
        message = (
            f"Command {event.command_name} with request id "
            f"{event.request_id} on server {event.connection_id} "
            f"failed in {event.duration_micros} "
            "microseconds"
        )
        logging.info(message)
|
||||
|
||||
|
||||
class ServerLogger(monitoring.ServerListener):
    """A simple listener that logs server discovery events.

    Handles :class:`~pymongo.monitoring.ServerOpeningEvent`,
    :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`,
    and :class:`~pymongo.monitoring.ServerClosedEvent` events using
    :mod:`logging`. Opening and type changes are logged at the `INFO`
    severity level; removal of a server is logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def opened(self, event: monitoring.ServerOpeningEvent) -> None:
        """Log that a server was added to a topology."""
        message = f"Server {event.server_address} added to topology {event.topology_id}"
        logging.info(message)

    def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
        """Log a change of server type; other description changes are ignored."""
        type_before = event.previous_description.server_type
        type_after = event.new_description.server_type
        if type_after != type_before:
            # server_type_name was added in PyMongo 3.4
            message = (
                f"Server {event.server_address} changed type from "
                f"{event.previous_description.server_type_name} to "
                f"{event.new_description.server_type_name}"
            )
            logging.info(message)

    def closed(self, event: monitoring.ServerClosedEvent) -> None:
        """Log, at warning level, that a server was removed from a topology."""
        message = f"Server {event.server_address} removed from topology {event.topology_id}"
        logging.warning(message)
|
||||
|
||||
|
||||
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
    """A simple listener that logs server heartbeat events.

    Handles :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`,
    :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`,
    and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` events using
    :mod:`logging`. Started and succeeded heartbeats are logged at the `INFO`
    severity level; failed heartbeats are logged at `WARNING`.

    .. versionadded:: 3.11
    """

    def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
        """Log that a heartbeat was sent."""
        message = f"Heartbeat sent to server {event.connection_id}"
        logging.info(message)

    def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
        """Log a successful heartbeat together with the reply document."""
        # The reply.document attribute was added in PyMongo 3.4.
        message = (
            f"Heartbeat to server {event.connection_id} "
            "succeeded with reply "
            f"{event.reply.document}"
        )
        logging.info(message)

    def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
        """Log, at warning level, a failed heartbeat and its error."""
        message = f"Heartbeat to server {event.connection_id} failed with error {event.reply}"
        logging.warning(message)
|
||||
|
||||
|
||||
class TopologyLogger(monitoring.TopologyListener):
    """Listener that reports server topology events through :mod:`logging`.

    Handles :class:`~pymongo.monitoring.TopologyOpenedEvent`,
    :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`,
    and :class:`~pymongo.monitoring.TopologyClosedEvent`.
    Events are logged at the `INFO` severity level; a description that
    reports no writable or no readable server additionally produces a
    `WARNING`.

    .. versionadded:: 3.11
    """

    def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
        """Log that a topology was opened."""
        message = f"Topology with id {event.topology_id} opened"
        logging.info(message)

    def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
        """Log a topology description update and any resulting type change."""
        logging.info(f"Topology description updated for topology id {event.topology_id}")

        before = event.previous_description
        after = event.new_description
        old_type = before.topology_type
        new_type = after.topology_type
        if new_type != old_type:
            # topology_type_name was added in PyMongo 3.4
            transition = (
                f"Topology {event.topology_id} changed type from "
                f"{before.topology_type_name} to "
                f"{after.topology_type_name}"
            )
            logging.info(transition)

        # The has_writable_server and has_readable_server methods
        # were added in PyMongo 3.4.
        if not after.has_writable_server():
            logging.warning("No writable servers available.")
        if not after.has_readable_server():
            logging.warning("No readable servers available.")

    def closed(self, event: monitoring.TopologyClosedEvent) -> None:
        """Log that a topology was closed."""
        message = f"Topology with id {event.topology_id} closed"
        logging.info(message)
|
||||
|
||||
|
||||
class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
    """A simple listener that logs server connection pool events.

    Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
    :class:`~pymongo.monitoring.PoolReadyEvent`,
    :class:`~pymongo.monitoring.PoolClearedEvent`,
    :class:`~pymongo.monitoring.PoolClosedEvent`,
    :class:`~pymongo.monitoring.ConnectionCreatedEvent`,
    :class:`~pymongo.monitoring.ConnectionReadyEvent`,
    :class:`~pymongo.monitoring.ConnectionClosedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
    :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
    and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
    events and logs them at the `INFO` severity level using :mod:`logging`.

    Every message is prefixed with the pool address and, for
    connection-specific events, the connection id.

    .. versionadded:: 3.11
    """

    def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
        """Log that a connection pool was created."""
        logging.info(f"[pool {event.address}] pool created")

    def pool_ready(self, event: monitoring.PoolReadyEvent) -> None:
        """Log that a connection pool was marked ready."""
        logging.info(f"[pool {event.address}] pool ready")

    def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
        """Log that a connection pool was cleared."""
        logging.info(f"[pool {event.address}] pool cleared")

    def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
        """Log that a connection pool was closed."""
        logging.info(f"[pool {event.address}] pool closed")

    def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
        """Log that a connection was created in the pool."""
        logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created")

    def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
        """Log that a connection finished its setup."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded"
        )

    def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
        """Log that a connection was closed, including the reported reason."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] "
            f'connection closed, reason: "{event.reason}"'
        )

    def connection_check_out_started(
        self, event: monitoring.ConnectionCheckOutStartedEvent
    ) -> None:
        """Log that a connection check out was started."""
        logging.info(f"[pool {event.address}] connection check out started")

    def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
        """Log that a connection check out failed, including the reported reason."""
        logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}")

    def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
        """Log that a connection was checked out of the pool."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool"
        )

    def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
        """Log that a connection was checked back into the pool."""
        logging.info(
            f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool"
        )
|
||||
219
pymongo/synchronous/hello.py
Normal file
219
pymongo/synchronous/hello.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright 2021-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers for the 'hello' and legacy hello commands."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import itertools
|
||||
from typing import Any, Generic, Mapping, Optional
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import common
|
||||
from pymongo.synchronous.hello_compat import HelloCompat
|
||||
from pymongo.synchronous.typings import ClusterTime, _DocumentType
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _get_server_type(doc: Mapping[str, Any]) -> int:
    """Classify the server that produced the hello response ``doc``.

    The checks run in a fixed order and the first match wins: a falsy
    ``ok`` yields ``Unknown``, then ``serviceId`` (load balancer),
    ``isreplicaset`` (ghost), ``setName`` (replica set member, further
    refined by role), and ``msg == "isdbgrid"`` (mongos). Anything else is
    treated as a standalone server.
    """
    if not doc.get("ok"):
        return SERVER_TYPE.Unknown

    if doc.get("serviceId"):
        return SERVER_TYPE.LoadBalancer
    if doc.get("isreplicaset"):
        return SERVER_TYPE.RSGhost

    if doc.get("setName"):
        # Replica set member: a hidden flag takes precedence over any role.
        if doc.get("hidden"):
            return SERVER_TYPE.RSOther
        # Either the modern or the legacy primary field marks a primary.
        if doc.get(HelloCompat.PRIMARY) or doc.get(HelloCompat.LEGACY_PRIMARY):
            return SERVER_TYPE.RSPrimary
        if doc.get("secondary"):
            return SERVER_TYPE.RSSecondary
        if doc.get("arbiterOnly"):
            return SERVER_TYPE.RSArbiter
        return SERVER_TYPE.RSOther

    if doc.get("msg") == "isdbgrid":
        return SERVER_TYPE.Mongos
    return SERVER_TYPE.Standalone
|
||||
|
||||
|
||||
class Hello(Generic[_DocumentType]):
    """Typed accessors over a hello response document from the server.

    The server type, writability and readability are computed once at
    construction time; the remaining properties read the stored document
    on every access.

    .. versionadded:: 3.12
    """

    __slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable")

    def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None:
        """Store ``doc`` and derive the server type and read/write flags.

        ``awaitable`` is kept as given and exposed via :attr:`awaitable`.
        """
        server_type = _get_server_type(doc)
        writable = server_type in (
            SERVER_TYPE.RSPrimary,
            SERVER_TYPE.Standalone,
            SERVER_TYPE.Mongos,
            SERVER_TYPE.LoadBalancer,
        )

        self._server_type = server_type
        self._doc: _DocumentType = doc
        self._is_writable = writable
        # Readable means: a secondary, or any of the writable server types.
        self._is_readable = server_type == SERVER_TYPE.RSSecondary or writable
        self._awaitable = awaitable

    @property
    def document(self) -> _DocumentType:
        """The complete hello command response document.

        A shallow copy is returned on each access.

        .. versionadded:: 3.4
        """
        return copy.copy(self._doc)

    @property
    def server_type(self) -> int:
        """The server type derived from the response at construction."""
        return self._server_type

    @property
    def all_hosts(self) -> set[tuple[str, int]]:
        """List of hosts, passives, and arbiters known to this server."""
        members = itertools.chain(
            self._doc.get("hosts", []),
            self._doc.get("passives", []),
            self._doc.get("arbiters", []),
        )
        return {common.clean_node(member) for member in members}

    @property
    def tags(self) -> Mapping[str, Any]:
        """Replica set member tags or empty dict."""
        return self._doc.get("tags", {})

    @property
    def primary(self) -> Optional[tuple[str, int]]:
        """This server's opinion about who the primary is, or None."""
        node = self._doc.get("primary")
        if not node:
            return None
        return common.partition_node(node)

    @property
    def replica_set_name(self) -> Optional[str]:
        """Replica set name or None."""
        return self._doc.get("setName")

    @property
    def max_bson_size(self) -> int:
        """The ``maxBsonObjectSize`` field, or ``common.MAX_BSON_SIZE`` if absent."""
        return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE)

    @property
    def max_message_size(self) -> int:
        """The ``maxMessageSizeBytes`` field, or twice :attr:`max_bson_size` if absent."""
        fallback = 2 * self.max_bson_size
        return self._doc.get("maxMessageSizeBytes", fallback)

    @property
    def max_write_batch_size(self) -> int:
        """The ``maxWriteBatchSize`` field, or ``common.MAX_WRITE_BATCH_SIZE`` if absent."""
        return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE)

    @property
    def min_wire_version(self) -> int:
        """The ``minWireVersion`` field, or ``common.MIN_WIRE_VERSION`` if absent."""
        return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION)

    @property
    def max_wire_version(self) -> int:
        """The ``maxWireVersion`` field, or ``common.MAX_WIRE_VERSION`` if absent."""
        return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION)

    @property
    def set_version(self) -> Optional[int]:
        """The ``setVersion`` field, or None."""
        return self._doc.get("setVersion")

    @property
    def election_id(self) -> Optional[ObjectId]:
        """The ``electionId`` field, or None."""
        return self._doc.get("electionId")

    @property
    def cluster_time(self) -> Optional[ClusterTime]:
        """The ``$clusterTime`` field, or None."""
        return self._doc.get("$clusterTime")

    @property
    def logical_session_timeout_minutes(self) -> Optional[int]:
        """The ``logicalSessionTimeoutMinutes`` field, or None."""
        return self._doc.get("logicalSessionTimeoutMinutes")

    @property
    def is_writable(self) -> bool:
        """Whether the server type is primary, standalone, mongos or load balancer."""
        return self._is_writable

    @property
    def is_readable(self) -> bool:
        """Whether the server is a secondary or is writable."""
        return self._is_readable

    @property
    def me(self) -> Optional[tuple[str, int]]:
        """The cleaned ``me`` address from the response, or None when missing or empty."""
        address = self._doc.get("me")
        if not address:
            return None
        return common.clean_node(address)

    @property
    def last_write_date(self) -> Optional[datetime.datetime]:
        """The ``lastWrite.lastWriteDate`` value, or None."""
        last_write = self._doc.get("lastWrite", {})
        return last_write.get("lastWriteDate")

    @property
    def compressors(self) -> Optional[list[str]]:
        """The ``compression`` field, or None."""
        return self._doc.get("compression")

    @property
    def sasl_supported_mechs(self) -> list[str]:
        """Supported authentication mechanisms for the current user.

        For example::

            >>> hello.sasl_supported_mechs
            ["SCRAM-SHA-1", "SCRAM-SHA-256"]

        """
        return self._doc.get("saslSupportedMechs", [])

    @property
    def speculative_authenticate(self) -> Optional[Mapping[str, Any]]:
        """The speculativeAuthenticate field."""
        return self._doc.get("speculativeAuthenticate")

    @property
    def topology_version(self) -> Optional[Mapping[str, Any]]:
        """The ``topologyVersion`` field, or None."""
        return self._doc.get("topologyVersion")

    @property
    def awaitable(self) -> bool:
        """The ``awaitable`` flag passed to the constructor."""
        return self._awaitable

    @property
    def service_id(self) -> Optional[ObjectId]:
        """The ``serviceId`` field, or None."""
        return self._doc.get("serviceId")

    @property
    def hello_ok(self) -> bool:
        """The ``helloOk`` field, or False if absent."""
        return self._doc.get("helloOk", False)

    @property
    def connection_id(self) -> Optional[int]:
        """The ``connectionId`` field, or None."""
        return self._doc.get("connectionId")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user