PYTHON-3801 gRPC POC phase 1 (#1317)
This commit is contained in:
parent
c259dde1de
commit
34ca694c9f
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.database import Database
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server import Server
|
||||
from pymongo.typings import _Pipeline
|
||||
@ -52,7 +52,7 @@ class _AggregationCommand:
|
||||
explicit_session: bool,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
user_fields: Optional[MutableMapping[str, Any]] = None,
|
||||
result_processor: Optional[Callable[[Mapping[str, Any], SocketInfo], None]] = None,
|
||||
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
if "explain" in options:
|
||||
@ -134,7 +134,7 @@ class _AggregationCommand:
|
||||
self,
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> CommandCursor:
|
||||
# Serialize command.
|
||||
@ -146,7 +146,7 @@ class _AggregationCommand:
|
||||
# - server version is >= 4.2 or
|
||||
# - server version is >= 3.2 and pipeline doesn't use $out
|
||||
if ("readConcern" not in cmd) and (
|
||||
not self._performs_write or (sock_info.max_wire_version >= 8)
|
||||
not self._performs_write or (connection.max_wire_version >= 8)
|
||||
):
|
||||
read_concern = self._target.read_concern
|
||||
else:
|
||||
@ -161,7 +161,7 @@ class _AggregationCommand:
|
||||
write_concern = None
|
||||
|
||||
# Run command.
|
||||
result = sock_info.command(
|
||||
result = connection.command(
|
||||
self._database.name,
|
||||
cmd,
|
||||
read_preference,
|
||||
@ -176,7 +176,7 @@ class _AggregationCommand:
|
||||
)
|
||||
|
||||
if self._result_processor:
|
||||
self._result_processor(result, sock_info)
|
||||
self._result_processor(result, connection)
|
||||
|
||||
# Extract cursor from result or mock/fake one if necessary.
|
||||
if "cursor" in result:
|
||||
@ -193,14 +193,14 @@ class _AggregationCommand:
|
||||
cmd_cursor = self._cursor_class(
|
||||
self._cursor_collection(cursor),
|
||||
cursor,
|
||||
sock_info.address,
|
||||
connection.address,
|
||||
batch_size=self._batch_size or 0,
|
||||
max_await_time_ms=self._max_await_time_ms,
|
||||
session=session,
|
||||
explicit_session=self._explicit_session,
|
||||
comment=self._options.get("comment"),
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(connection)
|
||||
return cmd_cursor
|
||||
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
@ -221,7 +221,7 @@ def _authenticate_scram_start(
|
||||
|
||||
|
||||
def _authenticate_scram(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, mechanism: str
|
||||
credentials: MongoCredential, connection: Connection, mechanism: str
|
||||
) -> None:
|
||||
"""Authenticate using SCRAM."""
|
||||
username = credentials.username
|
||||
@ -239,13 +239,13 @@ def _authenticate_scram(
|
||||
# Make local
|
||||
_hmac = hmac.HMAC
|
||||
|
||||
ctx = sock_info.auth_ctx
|
||||
ctx = connection.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
nonce, first_bare = ctx.scram_data
|
||||
res = ctx.speculative_authenticate
|
||||
else:
|
||||
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
|
||||
res = sock_info.command(source, cmd)
|
||||
res = connection.command(source, cmd)
|
||||
|
||||
server_first = res["payload"]
|
||||
parsed = _parse_scram_response(server_first)
|
||||
@ -285,7 +285,7 @@ def _authenticate_scram(
|
||||
("payload", Binary(client_final)),
|
||||
]
|
||||
)
|
||||
res = sock_info.command(source, cmd)
|
||||
res = connection.command(source, cmd)
|
||||
|
||||
parsed = _parse_scram_response(res["payload"])
|
||||
if not hmac.compare_digest(parsed[b"v"], server_sig):
|
||||
@ -301,7 +301,7 @@ def _authenticate_scram(
|
||||
("payload", Binary(b"")),
|
||||
]
|
||||
)
|
||||
res = sock_info.command(source, cmd)
|
||||
res = connection.command(source, cmd)
|
||||
if not res["done"]:
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
@ -345,7 +345,7 @@ def _canonicalize_hostname(hostname: str) -> str:
|
||||
return name[0].lower()
|
||||
|
||||
|
||||
def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_gssapi(credentials: MongoCredential, connection: Connection) -> None:
|
||||
"""Authenticate using GSSAPI."""
|
||||
if not HAVE_KERBEROS:
|
||||
raise ConfigurationError(
|
||||
@ -358,7 +358,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
props = credentials.mechanism_properties
|
||||
# Starting here and continuing through the while loop below - establish
|
||||
# the security context. See RFC 4752, Section 3.1, first paragraph.
|
||||
host = sock_info.address[0]
|
||||
host = connection.address[0]
|
||||
if props.canonicalize_host_name:
|
||||
host = _canonicalize_hostname(host)
|
||||
service = props.service_name + "@" + host
|
||||
@ -413,7 +413,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
response = sock_info.command("$external", cmd)
|
||||
response = connection.command("$external", cmd)
|
||||
|
||||
# Limit how many times we loop to catch protocol / library issues
|
||||
for _ in range(10):
|
||||
@ -430,7 +430,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("payload", payload),
|
||||
]
|
||||
)
|
||||
response = sock_info.command("$external", cmd)
|
||||
response = connection.command("$external", cmd)
|
||||
|
||||
if result == kerberos.AUTH_GSS_COMPLETE:
|
||||
break
|
||||
@ -453,7 +453,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("payload", payload),
|
||||
]
|
||||
)
|
||||
sock_info.command("$external", cmd)
|
||||
connection.command("$external", cmd)
|
||||
|
||||
finally:
|
||||
kerberos.authGSSClientClean(ctx)
|
||||
@ -462,7 +462,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
raise OperationFailure(str(exc))
|
||||
|
||||
|
||||
def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_plain(credentials: MongoCredential, connection: Connection) -> None:
|
||||
"""Authenticate using SASL PLAIN (RFC 4616)"""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
@ -476,52 +476,52 @@ def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
sock_info.command(source, cmd)
|
||||
connection.command(source, cmd)
|
||||
|
||||
|
||||
def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_x509(credentials: MongoCredential, connection: Connection) -> None:
|
||||
"""Authenticate using MONGODB-X509."""
|
||||
ctx = sock_info.auth_ctx
|
||||
ctx = connection.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
# MONGODB-X509 is done after the speculative auth step.
|
||||
return
|
||||
|
||||
cmd = _X509Context(credentials, sock_info.address).speculate_command()
|
||||
sock_info.command("$external", cmd)
|
||||
cmd = _X509Context(credentials, connection.address).speculate_command()
|
||||
connection.command("$external", cmd)
|
||||
|
||||
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, connection: Connection) -> None:
|
||||
"""Authenticate using MONGODB-CR."""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
password = credentials.password
|
||||
# Get a nonce
|
||||
response = sock_info.command(source, {"getnonce": 1})
|
||||
response = connection.command(source, {"getnonce": 1})
|
||||
nonce = response["nonce"]
|
||||
key = _auth_key(nonce, username, password)
|
||||
|
||||
# Actually authenticate
|
||||
query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)])
|
||||
sock_info.command(source, query)
|
||||
connection.command(source, query)
|
||||
|
||||
|
||||
def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
if sock_info.max_wire_version >= 7:
|
||||
if sock_info.negotiated_mechs:
|
||||
mechs = sock_info.negotiated_mechs
|
||||
def _authenticate_default(credentials: MongoCredential, connection: Connection) -> None:
|
||||
if connection.max_wire_version >= 7:
|
||||
if connection.negotiated_mechs:
|
||||
mechs = connection.negotiated_mechs
|
||||
else:
|
||||
source = credentials.source
|
||||
cmd = sock_info.hello_cmd()
|
||||
cmd = connection.hello_cmd()
|
||||
cmd["saslSupportedMechs"] = source + "." + credentials.username
|
||||
mechs = sock_info.command(source, cmd, publish_events=False).get(
|
||||
mechs = connection.command(source, cmd, publish_events=False).get(
|
||||
"saslSupportedMechs", []
|
||||
)
|
||||
if "SCRAM-SHA-256" in mechs:
|
||||
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256")
|
||||
return _authenticate_scram(credentials, connection, "SCRAM-SHA-256")
|
||||
else:
|
||||
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1")
|
||||
return _authenticate_scram(credentials, connection, "SCRAM-SHA-1")
|
||||
else:
|
||||
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1")
|
||||
return _authenticate_scram(credentials, connection, "SCRAM-SHA-1")
|
||||
|
||||
|
||||
_AUTH_MAP: Mapping[str, Callable] = {
|
||||
@ -606,12 +606,12 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
|
||||
|
||||
|
||||
def authenticate(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False
|
||||
credentials: MongoCredential, connection: Connection, reauthenticate: bool = False
|
||||
) -> None:
|
||||
"""Authenticate sock_info."""
|
||||
"""Authenticate connection."""
|
||||
mechanism = credentials.mechanism
|
||||
auth_func = _AUTH_MAP[mechanism]
|
||||
if mechanism == "MONGODB-OIDC":
|
||||
_authenticate_oidc(credentials, sock_info, reauthenticate)
|
||||
_authenticate_oidc(credentials, connection, reauthenticate)
|
||||
else:
|
||||
auth_func(credentials, sock_info)
|
||||
auth_func(credentials, connection)
|
||||
|
||||
@ -49,7 +49,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
|
||||
if TYPE_CHECKING:
|
||||
from bson.typings import _ReadableBuffer
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
class _AwsSaslContext(AwsSaslContext): # type: ignore
|
||||
@ -67,7 +67,7 @@ class _AwsSaslContext(AwsSaslContext): # type: ignore
|
||||
return bson.decode(data)
|
||||
|
||||
|
||||
def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_aws(credentials: MongoCredential, connection: Connection) -> None:
|
||||
"""Authenticate using MONGODB-AWS."""
|
||||
if not _HAVE_MONGODB_AWS:
|
||||
raise ConfigurationError(
|
||||
@ -75,7 +75,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
|
||||
"install with: python -m pip install 'pymongo[aws]'"
|
||||
)
|
||||
|
||||
if sock_info.max_wire_version < 9:
|
||||
if connection.max_wire_version < 9:
|
||||
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
|
||||
|
||||
try:
|
||||
@ -90,7 +90,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
|
||||
client_first = SON(
|
||||
[("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)]
|
||||
)
|
||||
server_first = sock_info.command("$external", client_first)
|
||||
server_first = connection.command("$external", client_first)
|
||||
res = server_first
|
||||
# Limit how many times we loop to catch protocol / library issues
|
||||
for _ in range(10):
|
||||
@ -102,7 +102,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
|
||||
("payload", client_payload),
|
||||
]
|
||||
)
|
||||
res = sock_info.command("$external", cmd)
|
||||
res = connection.command("$external", cmd)
|
||||
if res["done"]:
|
||||
# SASL complete.
|
||||
break
|
||||
|
||||
@ -29,7 +29,7 @@ from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -243,24 +243,24 @@ class _OIDCAuthenticator:
|
||||
self.token_exp_utc = None
|
||||
|
||||
def run_command(
|
||||
self, sock_info: SocketInfo, cmd: Mapping[str, Any]
|
||||
self, connection: Connection, cmd: Mapping[str, Any]
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
try:
|
||||
return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
return connection.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
except OperationFailure as exc:
|
||||
self.clear()
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
if "jwt" in bson.decode(cmd["payload"]):
|
||||
if self.idp_info_gen_id > self.reauth_gen_id:
|
||||
raise
|
||||
return self.authenticate(sock_info, reauthenticate=True)
|
||||
return self.authenticate(connection, reauthenticate=True)
|
||||
raise
|
||||
|
||||
def authenticate(
|
||||
self, sock_info: SocketInfo, reauthenticate: bool = False
|
||||
self, connection: Connection, reauthenticate: bool = False
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
if reauthenticate:
|
||||
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
|
||||
prev_id = getattr(connection, "oidc_token_gen_id", None)
|
||||
# Check if we've already changed tokens.
|
||||
if prev_id == self.token_gen_id:
|
||||
self.reauth_gen_id = self.idp_info_gen_id
|
||||
@ -268,7 +268,7 @@ class _OIDCAuthenticator:
|
||||
if not self.properties.refresh_token_callback:
|
||||
self.clear()
|
||||
|
||||
ctx = sock_info.auth_ctx
|
||||
ctx = connection.auth_ctx
|
||||
cmd = None
|
||||
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
@ -276,10 +276,10 @@ class _OIDCAuthenticator:
|
||||
else:
|
||||
cmd = self.auth_start_cmd()
|
||||
assert cmd is not None
|
||||
resp = self.run_command(sock_info, cmd)
|
||||
resp = self.run_command(connection, cmd)
|
||||
|
||||
if resp["done"]:
|
||||
sock_info.oidc_token_gen_id = self.token_gen_id
|
||||
connection.oidc_token_gen_id = self.token_gen_id
|
||||
return None
|
||||
|
||||
server_resp: Dict = bson.decode(resp["payload"])
|
||||
@ -289,7 +289,7 @@ class _OIDCAuthenticator:
|
||||
|
||||
conversation_id = resp["conversationId"]
|
||||
token = self.get_current_token()
|
||||
sock_info.oidc_token_gen_id = self.token_gen_id
|
||||
connection.oidc_token_gen_id = self.token_gen_id
|
||||
bin_payload = Binary(bson.encode({"jwt": token}))
|
||||
cmd = SON(
|
||||
[
|
||||
@ -298,7 +298,7 @@ class _OIDCAuthenticator:
|
||||
("payload", bin_payload),
|
||||
]
|
||||
)
|
||||
resp = self.run_command(sock_info, cmd)
|
||||
resp = self.run_command(connection, cmd)
|
||||
if not resp["done"]:
|
||||
self.clear()
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
@ -306,8 +306,8 @@ class _OIDCAuthenticator:
|
||||
|
||||
|
||||
def _authenticate_oidc(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool
|
||||
credentials: MongoCredential, connection: Connection, reauthenticate: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Authenticate using MONGODB-OIDC."""
|
||||
authenticator = _get_authenticator(credentials, sock_info.address)
|
||||
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)
|
||||
authenticator = _get_authenticator(credentials, connection.address)
|
||||
return authenticator.authenticate(connection, reauthenticate=reauthenticate)
|
||||
|
||||
@ -65,7 +65,7 @@ from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
@ -311,7 +311,7 @@ class _Bulk:
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[ClientSession],
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
op_id: int,
|
||||
retryable: bool,
|
||||
full_result: MutableMapping[str, Any],
|
||||
@ -326,9 +326,9 @@ class _Bulk:
|
||||
self.next_run = None
|
||||
run = self.current_run
|
||||
|
||||
# sock_info.command validates the session, but we use
|
||||
# sock_info.write_command.
|
||||
sock_info.validate_session(client, session)
|
||||
# connection.command validates the session, but we use
|
||||
# connection.write_command.
|
||||
connection.validate_session(client, session)
|
||||
last_run = False
|
||||
|
||||
while run:
|
||||
@ -341,7 +341,7 @@ class _Bulk:
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
sock_info,
|
||||
connection,
|
||||
op_id,
|
||||
listeners,
|
||||
session,
|
||||
@ -369,11 +369,11 @@ class _Bulk:
|
||||
if retryable and not self.started_retryable_write:
|
||||
session._start_retryable_write()
|
||||
self.started_retryable_write = True
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info)
|
||||
sock_info.send_cluster_time(cmd, session, client)
|
||||
sock_info.add_server_api(cmd)
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, connection)
|
||||
connection.send_cluster_time(cmd, session, client)
|
||||
connection.add_server_api(cmd)
|
||||
# CSOT: apply timeout before encoding the command.
|
||||
sock_info.apply_timeout(client, cmd)
|
||||
connection.apply_timeout(client, cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
|
||||
# Run as many ops as possible in one command.
|
||||
@ -430,13 +430,13 @@ class _Bulk:
|
||||
op_id = _randint()
|
||||
|
||||
def retryable_bulk(
|
||||
session: Optional[ClientSession], sock_info: SocketInfo, retryable: bool
|
||||
session: Optional[ClientSession], connection: Connection, retryable: bool
|
||||
) -> None:
|
||||
self._execute_command(
|
||||
generator,
|
||||
write_concern,
|
||||
session,
|
||||
sock_info,
|
||||
connection,
|
||||
op_id,
|
||||
retryable,
|
||||
full_result,
|
||||
@ -450,7 +450,7 @@ class _Bulk:
|
||||
_raise_bulk_write_error(full_result)
|
||||
return full_result
|
||||
|
||||
def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[Any]) -> None:
|
||||
def execute_op_msg_no_results(self, connection: Connection, generator: Iterator[Any]) -> None:
|
||||
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
|
||||
db_name = self.collection.database.name
|
||||
client = self.collection.database.client
|
||||
@ -466,7 +466,7 @@ class _Bulk:
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
sock_info,
|
||||
connection,
|
||||
op_id,
|
||||
listeners,
|
||||
None,
|
||||
@ -482,7 +482,7 @@ class _Bulk:
|
||||
("writeConcern", {"w": 0}),
|
||||
]
|
||||
)
|
||||
sock_info.add_server_api(cmd)
|
||||
connection.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible.
|
||||
to_send = bwc.execute_unack(cmd, ops, client)
|
||||
@ -491,7 +491,7 @@ class _Bulk:
|
||||
|
||||
def execute_command_no_results(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
@ -516,7 +516,7 @@ class _Bulk:
|
||||
generator,
|
||||
initial_write_concern,
|
||||
None,
|
||||
sock_info,
|
||||
connection,
|
||||
op_id,
|
||||
False,
|
||||
full_result,
|
||||
@ -527,7 +527,7 @@ class _Bulk:
|
||||
|
||||
def execute_no_results(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
@ -538,11 +538,11 @@ class _Bulk:
|
||||
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
|
||||
# Guard against unsupported unacknowledged writes.
|
||||
unack = write_concern and not write_concern.acknowledged
|
||||
if unack and self.uses_hint_delete and sock_info.max_wire_version < 9:
|
||||
if unack and self.uses_hint_delete and connection.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
if unack and self.uses_hint_update and sock_info.max_wire_version < 8:
|
||||
if unack and self.uses_hint_update and connection.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
@ -553,8 +553,8 @@ class _Bulk:
|
||||
)
|
||||
|
||||
if self.ordered:
|
||||
return self.execute_command_no_results(sock_info, generator, write_concern)
|
||||
return self.execute_op_msg_no_results(sock_info, generator)
|
||||
return self.execute_command_no_results(connection, generator, write_concern)
|
||||
return self.execute_op_msg_no_results(connection, generator)
|
||||
|
||||
def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any:
|
||||
"""Execute operations."""
|
||||
@ -573,8 +573,8 @@ class _Bulk:
|
||||
|
||||
client = self.collection.database.client
|
||||
if not write_concern.acknowledged:
|
||||
with client._socket_for_writes(session) as sock_info:
|
||||
self.execute_no_results(sock_info, generator, write_concern)
|
||||
with client._socket_for_writes(session) as connection:
|
||||
self.execute_no_results(connection, generator, write_concern)
|
||||
return None
|
||||
else:
|
||||
return self.execute_command(generator, write_concern, session)
|
||||
|
||||
@ -78,7 +78,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
|
||||
@ -213,7 +213,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> None:
|
||||
def _process_result(self, result: Mapping[str, Any], connection: Connection) -> None:
|
||||
"""Callback that caches the postBatchResumeToken or
|
||||
startAtOperationTime from a changeStream aggregate command response
|
||||
containing an empty batch of change documents.
|
||||
@ -228,7 +228,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
self._start_at_operation_time is None
|
||||
and self._uses_resume_after is False
|
||||
and self._uses_start_after is False
|
||||
and sock_info.max_wire_version >= 7
|
||||
and connection.max_wire_version >= 7
|
||||
):
|
||||
self._start_at_operation_time = result.get("operationTime")
|
||||
# PYTHON-2181: informative error on missing operationTime.
|
||||
|
||||
@ -24,7 +24,7 @@ from pymongo.common import validate_boolean
|
||||
from pymongo.compression_support import CompressionSettings
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.pool import ConnectionProtocol, PoolOptions
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import (
|
||||
_ServerMode,
|
||||
@ -162,6 +162,13 @@ def _parse_pool_options(
|
||||
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
|
||||
load_balanced = options.get("loadbalanced")
|
||||
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
|
||||
|
||||
grpc_enabled = options.get("grpc", False)
|
||||
if grpc_enabled:
|
||||
protocol = ConnectionProtocol.GRPC
|
||||
else:
|
||||
protocol = ConnectionProtocol.TCP_SOCKET
|
||||
|
||||
return PoolOptions(
|
||||
max_pool_size,
|
||||
min_pool_size,
|
||||
@ -179,6 +186,7 @@ def _parse_pool_options(
|
||||
server_api=server_api,
|
||||
load_balanced=load_balanced,
|
||||
credentials=credentials,
|
||||
protocol=protocol,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -178,7 +178,7 @@ from pymongo.write_concern import WriteConcern
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.server import Server
|
||||
|
||||
|
||||
@ -412,17 +412,17 @@ class _Transaction:
|
||||
return self.state == _TxnState.STARTING
|
||||
|
||||
@property
|
||||
def pinned_conn(self) -> Optional[SocketInfo]:
|
||||
def pinned_conn(self) -> Optional[Connection]:
|
||||
if self.active() and self.sock_mgr:
|
||||
return self.sock_mgr.sock
|
||||
return None
|
||||
|
||||
def pin(self, server: Server, sock_info: SocketInfo) -> None:
|
||||
def pin(self, server: Server, connection: Connection) -> None:
|
||||
self.sharded = True
|
||||
self.pinned_address = server.description.address
|
||||
if server.description.server_type == SERVER_TYPE.LoadBalancer:
|
||||
sock_info.pin_txn()
|
||||
self.sock_mgr = _SocketManager(sock_info, False)
|
||||
connection.pin_txn()
|
||||
self.sock_mgr = _SocketManager(connection, False)
|
||||
|
||||
def unpin(self) -> None:
|
||||
self.pinned_address = None
|
||||
@ -839,12 +839,12 @@ class ClientSession:
|
||||
- `command_name`: Either "commitTransaction" or "abortTransaction".
|
||||
"""
|
||||
|
||||
def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]:
|
||||
return self._finish_transaction(sock_info, command_name)
|
||||
def func(session: ClientSession, connection: Connection, retryable: bool) -> Dict[str, Any]:
|
||||
return self._finish_transaction(connection, command_name)
|
||||
|
||||
return self._client._retry_internal(True, func, self, None)
|
||||
|
||||
def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]:
|
||||
def _finish_transaction(self, connection: Connection, command_name: str) -> Dict[str, Any]:
|
||||
self._transaction.attempt += 1
|
||||
opts = self._transaction.opts
|
||||
assert opts
|
||||
@ -868,7 +868,7 @@ class ClientSession:
|
||||
cmd["recoveryToken"] = self._transaction.recovery_token
|
||||
|
||||
return self._client.admin._command(
|
||||
sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True
|
||||
connection, cmd, session=self, write_concern=wc, parse_write_concern_error=True
|
||||
)
|
||||
|
||||
def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
|
||||
@ -954,13 +954,13 @@ class ClientSession:
|
||||
return None
|
||||
|
||||
@property
|
||||
def _pinned_connection(self) -> Optional[SocketInfo]:
|
||||
def _pinned_connection(self) -> Optional[Connection]:
|
||||
"""The connection this transaction was started on."""
|
||||
return self._transaction.pinned_conn
|
||||
|
||||
def _pin(self, server: Server, sock_info: SocketInfo) -> None:
|
||||
def _pin(self, server: Server, connection: Connection) -> None:
|
||||
"""Pin this session to the given Server or to the given connection."""
|
||||
self._transaction.pin(server, sock_info)
|
||||
self._transaction.pin(server, connection)
|
||||
|
||||
def _unpin(self) -> None:
|
||||
"""Unpin this session from any pinned Server."""
|
||||
@ -985,12 +985,12 @@ class ClientSession:
|
||||
command: MutableMapping[str, Any],
|
||||
is_retryable: bool,
|
||||
read_preference: ReadPreference,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
self._check_ended()
|
||||
self._materialize()
|
||||
if self.options.snapshot:
|
||||
self._update_read_concern(command, sock_info)
|
||||
self._update_read_concern(command, connection)
|
||||
|
||||
self._server_session.last_use = time.monotonic()
|
||||
command["lsid"] = self._server_session.session_id
|
||||
@ -1016,7 +1016,7 @@ class ClientSession:
|
||||
rc = self._transaction.opts.read_concern.document
|
||||
if rc:
|
||||
command["readConcern"] = rc
|
||||
self._update_read_concern(command, sock_info)
|
||||
self._update_read_concern(command, connection)
|
||||
|
||||
command["txnNumber"] = self._server_session.transaction_id
|
||||
command["autocommit"] = False
|
||||
@ -1025,11 +1025,11 @@ class ClientSession:
|
||||
self._check_ended()
|
||||
self._server_session.inc_transaction_id()
|
||||
|
||||
def _update_read_concern(self, cmd: MutableMapping[str, Any], sock_info: SocketInfo) -> None:
|
||||
def _update_read_concern(self, cmd: MutableMapping[str, Any], connection: Connection) -> None:
|
||||
if self.options.causal_consistency and self.operation_time is not None:
|
||||
cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time
|
||||
if self.options.snapshot:
|
||||
if sock_info.max_wire_version < 13:
|
||||
if connection.max_wire_version < 13:
|
||||
raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later")
|
||||
rc = cmd.setdefault("readConcern", {})
|
||||
rc["level"] = "snapshot"
|
||||
|
||||
@ -126,7 +126,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.database import Database
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.server import Server
|
||||
|
||||
@ -264,15 +264,15 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _socket_for_reads(
|
||||
self, session: ClientSession
|
||||
) -> ContextManager[Tuple[SocketInfo, Union[PrimaryPreferred, Primary]]]:
|
||||
) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]:
|
||||
return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
|
||||
|
||||
def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[SocketInfo]:
|
||||
def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]:
|
||||
return self.__database.client._socket_for_writes(session)
|
||||
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
command: Mapping[str, Any],
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
@ -288,7 +288,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Internal command helper.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info` - A SocketInfo instance.
|
||||
- `connection` - A Connection instance.
|
||||
- `command` - The command itself, as a :class:`~bson.son.SON` instance.
|
||||
- `read_preference` (optional) - The read preference to use.
|
||||
- `codec_options` (optional) - An instance of
|
||||
@ -313,7 +313,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
The result document.
|
||||
"""
|
||||
with self.__database.client._tmp_session(session) as s:
|
||||
return sock_info.command(
|
||||
return connection.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
read_preference or self._read_preference_for(session),
|
||||
@ -348,16 +348,16 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
if "size" in options:
|
||||
options["size"] = float(options["size"])
|
||||
cmd.update(options)
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
if qev2_required and sock_info.max_wire_version < 21:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
if qev2_required and connection.max_wire_version < 21:
|
||||
raise ConfigurationError(
|
||||
"Driver support of Queryable Encryption is incompatible with server. "
|
||||
"Upgrade server to use Queryable Encryption. "
|
||||
f"Got maxWireVersion {sock_info.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
|
||||
f"Got maxWireVersion {connection.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
|
||||
)
|
||||
|
||||
self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=self._write_concern_for(session),
|
||||
@ -597,12 +597,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
command["comment"] = comment
|
||||
|
||||
def _insert_command(
|
||||
session: ClientSession, sock_info: SocketInfo, retryable_write: bool
|
||||
session: ClientSession, connection: Connection, retryable_write: bool
|
||||
) -> None:
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
|
||||
result = sock_info.command(
|
||||
result = connection.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
write_concern=write_concern,
|
||||
@ -765,7 +765,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _update(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
criteria: Mapping[str, Any],
|
||||
document: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
@ -801,7 +801,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
update_doc["arrayFilters"] = array_filters
|
||||
if hint is not None:
|
||||
if not acknowledged and sock_info.max_wire_version < 8:
|
||||
if not acknowledged and connection.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
@ -821,7 +821,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
# The command result has to be published for APM unmodified
|
||||
# so we make a shallow copy here before adding updatedExisting.
|
||||
result = sock_info.command(
|
||||
result = connection.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
write_concern=write_concern,
|
||||
@ -865,10 +865,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Internal update / replace helper."""
|
||||
|
||||
def _update(
|
||||
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool
|
||||
session: Optional[ClientSession], connection: Connection, retryable_write: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
return self._update(
|
||||
sock_info,
|
||||
connection,
|
||||
criteria,
|
||||
document,
|
||||
upsert=upsert,
|
||||
@ -1255,7 +1255,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _delete(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
criteria: Mapping[str, Any],
|
||||
multi: bool,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
@ -1280,7 +1280,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
delete_doc["collation"] = collation
|
||||
if hint is not None:
|
||||
if not acknowledged and sock_info.max_wire_version < 9:
|
||||
if not acknowledged and connection.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
@ -1297,7 +1297,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
command["comment"] = comment
|
||||
|
||||
# Delete command.
|
||||
result = sock_info.command(
|
||||
result = connection.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
write_concern=write_concern,
|
||||
@ -1325,10 +1325,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Internal delete helper."""
|
||||
|
||||
def _delete(
|
||||
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool
|
||||
session: Optional[ClientSession], connection: Connection, retryable_write: bool
|
||||
) -> Mapping[str, Any]:
|
||||
return self._delete(
|
||||
sock_info,
|
||||
connection,
|
||||
criteria,
|
||||
multi,
|
||||
write_concern=write_concern,
|
||||
@ -1738,7 +1738,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _count_cmd(
|
||||
self,
|
||||
session: ClientSession,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: Mapping[str, Any],
|
||||
collation: Optional[Collation],
|
||||
@ -1747,7 +1747,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
# XXX: "ns missing" checks can be removed when we drop support for
|
||||
# MongoDB 3.0, see SERVER-17051.
|
||||
res = self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=read_preference,
|
||||
allowable_errors=["ns missing"],
|
||||
@ -1762,7 +1762,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _aggregate_one_result(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn],
|
||||
@ -1770,7 +1770,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Internal helper to run an aggregate that returns a single result."""
|
||||
result = self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference,
|
||||
allowable_errors=[26], # Ignore NamespaceNotFound.
|
||||
@ -1821,12 +1821,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> int:
|
||||
cmd: SON[str, Any] = SON([("count", self.__name)])
|
||||
cmd.update(kwargs)
|
||||
return self._count_cmd(session, sock_info, read_preference, cmd, collation=None)
|
||||
return self._count_cmd(session, connection, read_preference, cmd, collation=None)
|
||||
|
||||
return self._retryable_non_cursor_read(_cmd, None)
|
||||
|
||||
@ -1910,10 +1910,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> int:
|
||||
result = self._aggregate_one_result(sock_info, read_preference, cmd, collation, session)
|
||||
result = self._aggregate_one_result(
|
||||
connection, read_preference, cmd, collation, session
|
||||
)
|
||||
if not result:
|
||||
return 0
|
||||
return result["n"]
|
||||
@ -1922,7 +1924,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _retryable_non_cursor_read(
|
||||
self,
|
||||
func: Callable[[ClientSession, Server, SocketInfo, Optional[_ServerMode]], T],
|
||||
func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T],
|
||||
session: Optional[ClientSession],
|
||||
) -> T:
|
||||
"""Non-cursor read helper to handle implicit session creation."""
|
||||
@ -1993,8 +1995,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
command (like maxTimeMS) can be passed as keyword arguments.
|
||||
"""
|
||||
names = []
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
supports_quorum = sock_info.max_wire_version >= 9
|
||||
with self._socket_for_writes(session) as connection:
|
||||
supports_quorum = connection.max_wire_version >= 9
|
||||
|
||||
def gen_indexes() -> Iterator[Mapping[str, Any]]:
|
||||
for index in indexes:
|
||||
@ -2015,7 +2017,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
@ -2236,9 +2238,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=["ns not found", 26],
|
||||
@ -2285,7 +2287,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> CommandCursor[_DocumentType]:
|
||||
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
|
||||
@ -2294,7 +2296,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
try:
|
||||
cursor = self._command(
|
||||
sock_info, cmd, read_preference, codec_options, session=session
|
||||
connection, cmd, read_preference, codec_options, session=session
|
||||
)["cursor"]
|
||||
except OperationFailure as exc:
|
||||
# Ignore NamespaceNotFound errors to match the behavior
|
||||
@ -2305,12 +2307,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
cursor,
|
||||
sock_info.address,
|
||||
connection.address,
|
||||
session=session,
|
||||
explicit_session=explicit_session,
|
||||
comment=cmd.get("comment"),
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(connection)
|
||||
return cmd_cursor
|
||||
|
||||
with self.__database.client._tmp_session(session, False) as s:
|
||||
@ -2479,9 +2481,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))])
|
||||
cmd.update(kwargs)
|
||||
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
resp = self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
@ -2514,9 +2516,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=["ns not found", 26],
|
||||
@ -2551,9 +2553,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=["ns not found", 26],
|
||||
@ -2980,9 +2982,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd["comment"] = comment
|
||||
write_concern = self._write_concern_for_cmd(cmd, session)
|
||||
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
with self.__database.client._tmp_session(session) as s:
|
||||
return sock_info.command(
|
||||
return connection.command(
|
||||
"admin",
|
||||
cmd,
|
||||
write_concern=write_concern,
|
||||
@ -3049,11 +3051,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> List:
|
||||
return self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=read_preference,
|
||||
read_concern=self.read_concern,
|
||||
@ -3112,7 +3114,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern = self._write_concern_for_cmd(cmd, session)
|
||||
|
||||
def _find_and_modify(
|
||||
session: ClientSession, sock_info: SocketInfo, retryable_write: bool
|
||||
session: ClientSession, connection: Connection, retryable_write: bool
|
||||
) -> Any:
|
||||
acknowledged = write_concern.acknowledged
|
||||
if array_filters is not None:
|
||||
@ -3122,17 +3124,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
cmd["arrayFilters"] = list(array_filters)
|
||||
if hint is not None:
|
||||
if sock_info.max_wire_version < 8:
|
||||
if connection.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on find and modify commands."
|
||||
)
|
||||
elif not acknowledged and sock_info.max_wire_version < 9:
|
||||
elif not acknowledged and connection.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands."
|
||||
)
|
||||
cmd["hint"] = hint
|
||||
out = self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=write_concern,
|
||||
|
||||
@ -38,7 +38,7 @@ from pymongo.typings import _Address, _DocumentType
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
class CommandCursor(Generic[_DocumentType]):
|
||||
@ -157,13 +157,13 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
"""
|
||||
return self.__postbatchresumetoken
|
||||
|
||||
def _maybe_pin_connection(self, sock_info: SocketInfo) -> None:
|
||||
def _maybe_pin_connection(self, connection: Connection) -> None:
|
||||
client = self.__collection.database.client
|
||||
if not client._should_pin_cursor(self.__session):
|
||||
return
|
||||
if not self.__sock_mgr:
|
||||
sock_info.pin_cursor()
|
||||
sock_mgr = _SocketManager(sock_info, False)
|
||||
connection.pin_cursor()
|
||||
sock_mgr = _SocketManager(connection, False)
|
||||
# Ensure the connection gets returned when the entire result is
|
||||
# returned in the first batch.
|
||||
if self.__id == 0:
|
||||
|
||||
@ -750,6 +750,7 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
|
||||
"server_selector": validate_is_callable_or_none,
|
||||
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
|
||||
"authoidcallowedhosts": validate_list,
|
||||
"grpc": validate_boolean,
|
||||
}
|
||||
|
||||
# Dictionary where keys are any URI option name, and values are the
|
||||
|
||||
@ -64,7 +64,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
|
||||
|
||||
@ -142,8 +142,8 @@ class CursorType:
|
||||
class _SocketManager:
|
||||
"""Used with exhaust cursors to ensure the socket is returned."""
|
||||
|
||||
def __init__(self, sock: SocketInfo, more_to_come: bool):
|
||||
self.sock: Optional[SocketInfo] = sock
|
||||
def __init__(self, sock: Connection, more_to_come: bool):
|
||||
self.sock: Optional[Connection] = sock
|
||||
self.more_to_come = more_to_come
|
||||
self.lock = _create_lock()
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.server import Server
|
||||
|
||||
|
||||
@ -689,7 +689,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
@overload
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -706,7 +706,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
@overload
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -722,7 +722,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -742,7 +742,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
command.update(kwargs)
|
||||
with self.__client._tmp_session(session) as s:
|
||||
return sock_info.command(
|
||||
return connection.command(
|
||||
self.__name,
|
||||
command,
|
||||
read_preference,
|
||||
@ -890,11 +890,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
if read_preference is None:
|
||||
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
|
||||
with self.__client._socket_for_reads(read_preference, session) as (
|
||||
sock_info,
|
||||
connection,
|
||||
read_preference,
|
||||
):
|
||||
return self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
command,
|
||||
value,
|
||||
check,
|
||||
@ -974,11 +974,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
tmp_session and tmp_session._txn_read_preference()
|
||||
) or ReadPreference.PRIMARY
|
||||
with self.__client._socket_for_reads(read_preference, tmp_session) as (
|
||||
sock_info,
|
||||
connection,
|
||||
read_preference,
|
||||
):
|
||||
response = self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
command,
|
||||
value,
|
||||
True,
|
||||
@ -993,13 +993,13 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
response["cursor"],
|
||||
sock_info.address,
|
||||
connection.address,
|
||||
max_await_time_ms=max_await_time_ms,
|
||||
session=tmp_session,
|
||||
explicit_session=session is not None,
|
||||
comment=comment,
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(connection)
|
||||
return cmd_cursor
|
||||
else:
|
||||
raise InvalidOperation("Command does not return a cursor.")
|
||||
@ -1015,11 +1015,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> Dict[str, Any]:
|
||||
return self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
command,
|
||||
read_preference=read_preference,
|
||||
session=session,
|
||||
@ -1029,7 +1029,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _list_collections(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
session: Optional[ClientSession],
|
||||
read_preference: _ServerMode,
|
||||
**kwargs: Any,
|
||||
@ -1040,17 +1040,17 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
with self.__client._tmp_session(session, close=False) as tmp_session:
|
||||
cursor = self._command(
|
||||
sock_info, cmd, read_preference=read_preference, session=tmp_session
|
||||
connection, cmd, read_preference=read_preference, session=tmp_session
|
||||
)["cursor"]
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
cursor,
|
||||
sock_info.address,
|
||||
connection.address,
|
||||
session=tmp_session,
|
||||
explicit_session=session is not None,
|
||||
comment=cmd.get("comment"),
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(connection)
|
||||
return cmd_cursor
|
||||
|
||||
def list_collections(
|
||||
@ -1090,11 +1090,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> CommandCursor[_DocumentType]:
|
||||
return self._list_collections(
|
||||
sock_info, session, read_preference=read_preference, **kwargs
|
||||
connection, session, read_preference=read_preference, **kwargs
|
||||
)
|
||||
|
||||
return self.__client._retryable_read(_cmd, read_pref, session)
|
||||
@ -1154,9 +1154,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
if comment is not None:
|
||||
command["comment"] = comment
|
||||
|
||||
with self.__client._socket_for_writes(session) as sock_info:
|
||||
with self.__client._socket_for_writes(session) as connection:
|
||||
return self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
command,
|
||||
allowable_errors=["ns not found", 26],
|
||||
write_concern=self._write_concern_for(session),
|
||||
|
||||
@ -307,7 +307,7 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
def _handle_reauth(func: F) -> F:
|
||||
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@ -315,19 +315,19 @@ def _handle_reauth(func: F) -> F:
|
||||
if no_reauth:
|
||||
raise
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
# Look for an argument that either is a SocketInfo
|
||||
# Look for an argument that either is a Connection
|
||||
# or has a socket_info attribute, so we can trigger
|
||||
# a reauth.
|
||||
sock_info = None
|
||||
connection = None
|
||||
for arg in args:
|
||||
if isinstance(arg, SocketInfo):
|
||||
sock_info = arg
|
||||
if isinstance(arg, Connection):
|
||||
connection = arg
|
||||
break
|
||||
if hasattr(arg, "sock_info"):
|
||||
sock_info = arg.sock_info
|
||||
if hasattr(arg, "connection"):
|
||||
connection = arg.connection
|
||||
break
|
||||
if sock_info:
|
||||
sock_info.authenticate(reauthenticate=True)
|
||||
if connection:
|
||||
connection.authenticate(reauthenticate=True)
|
||||
else:
|
||||
raise
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@ -227,14 +227,14 @@ def _gen_find_command(
|
||||
return cmd
|
||||
|
||||
|
||||
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, sock_info):
|
||||
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, connection):
|
||||
"""Generate a getMore command document."""
|
||||
cmd = SON([("getMore", cursor_id), ("collection", coll)])
|
||||
if batch_size:
|
||||
cmd["batchSize"] = batch_size
|
||||
if max_await_time_ms is not None:
|
||||
cmd["maxTimeMS"] = max_await_time_ms
|
||||
if comment is not None and sock_info.max_wire_version >= 9:
|
||||
if comment is not None and connection.max_wire_version >= 9:
|
||||
cmd["comment"] = comment
|
||||
return cmd
|
||||
|
||||
@ -311,24 +311,24 @@ class _Query:
|
||||
def namespace(self):
|
||||
return f"{self.db}.{self.coll}"
|
||||
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, connection):
|
||||
use_find_cmd = False
|
||||
if not self.exhaust:
|
||||
use_find_cmd = True
|
||||
elif sock_info.max_wire_version >= 8:
|
||||
elif connection.max_wire_version >= 8:
|
||||
# OP_MSG supports exhaust on MongoDB 4.2+
|
||||
use_find_cmd = True
|
||||
elif not self.read_concern.ok_for_legacy:
|
||||
raise ConfigurationError(
|
||||
"read concern level of %s is not valid "
|
||||
"with a max wire version of %d."
|
||||
% (self.read_concern.level, sock_info.max_wire_version)
|
||||
% (self.read_concern.level, connection.max_wire_version)
|
||||
)
|
||||
|
||||
sock_info.validate_session(self.client, self.session)
|
||||
connection.validate_session(self.client, self.session)
|
||||
return use_find_cmd
|
||||
|
||||
def as_command(self, sock_info, apply_timeout=False):
|
||||
def as_command(self, connection, apply_timeout=False):
|
||||
"""Return a find command document for this query."""
|
||||
# We use the command twice: on the wire and for command monitoring.
|
||||
# Generate it once, for speed and to avoid repeating side-effects.
|
||||
@ -353,24 +353,24 @@ class _Query:
|
||||
self.name = "explain"
|
||||
cmd = SON([("explain", cmd)])
|
||||
session = self.session
|
||||
sock_info.add_server_api(cmd)
|
||||
connection.add_server_api(cmd)
|
||||
if session:
|
||||
session._apply_to(cmd, False, self.read_preference, sock_info)
|
||||
session._apply_to(cmd, False, self.read_preference, connection)
|
||||
# Explain does not support readConcern.
|
||||
if not explain and not session.in_transaction:
|
||||
session._update_read_concern(cmd, sock_info)
|
||||
sock_info.send_cluster_time(cmd, session, self.client)
|
||||
session._update_read_concern(cmd, connection)
|
||||
connection.send_cluster_time(cmd, session, self.client)
|
||||
# Support auto encryption
|
||||
client = self.client
|
||||
if client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
|
||||
# Support CSOT
|
||||
if apply_timeout:
|
||||
sock_info.apply_timeout(client, cmd)
|
||||
connection.apply_timeout(client, cmd)
|
||||
self._as_command = cmd, self.db
|
||||
return self._as_command
|
||||
|
||||
def get_message(self, read_preference, sock_info, use_cmd=False):
|
||||
def get_message(self, read_preference, connection, use_cmd=False):
|
||||
"""Get a query message, possibly setting the secondaryOk bit."""
|
||||
# Use the read_preference decided by _socket_from_server.
|
||||
self.read_preference = read_preference
|
||||
@ -384,14 +384,14 @@ class _Query:
|
||||
spec = self.spec
|
||||
|
||||
if use_cmd:
|
||||
spec = self.as_command(sock_info, apply_timeout=True)[0]
|
||||
spec = self.as_command(connection, apply_timeout=True)[0]
|
||||
request_id, msg, size, _ = _op_msg(
|
||||
0,
|
||||
spec,
|
||||
self.db,
|
||||
read_preference,
|
||||
self.codec_options,
|
||||
ctx=sock_info.compression_context,
|
||||
ctx=connection.compression_context,
|
||||
)
|
||||
return request_id, msg, size
|
||||
|
||||
@ -405,7 +405,7 @@ class _Query:
|
||||
else:
|
||||
ntoreturn = self.limit
|
||||
|
||||
if sock_info.is_mongos:
|
||||
if connection.is_mongos:
|
||||
spec = _maybe_add_read_preference(spec, read_preference)
|
||||
|
||||
return _query(
|
||||
@ -416,7 +416,7 @@ class _Query:
|
||||
spec,
|
||||
None if use_cmd else self.fields,
|
||||
self.codec_options,
|
||||
ctx=sock_info.compression_context,
|
||||
ctx=connection.compression_context,
|
||||
)
|
||||
|
||||
|
||||
@ -476,18 +476,18 @@ class _GetMore:
|
||||
def namespace(self):
|
||||
return f"{self.db}.{self.coll}"
|
||||
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, connection):
|
||||
use_cmd = False
|
||||
if not self.exhaust:
|
||||
use_cmd = True
|
||||
elif sock_info.max_wire_version >= 8:
|
||||
elif connection.max_wire_version >= 8:
|
||||
# OP_MSG supports exhaust on MongoDB 4.2+
|
||||
use_cmd = True
|
||||
|
||||
sock_info.validate_session(self.client, self.session)
|
||||
connection.validate_session(self.client, self.session)
|
||||
return use_cmd
|
||||
|
||||
def as_command(self, sock_info, apply_timeout=False):
|
||||
def as_command(self, connection, apply_timeout=False):
|
||||
"""Return a getMore command document for this query."""
|
||||
# See _Query.as_command for an explanation of this caching.
|
||||
if self._as_command is not None:
|
||||
@ -499,35 +499,35 @@ class _GetMore:
|
||||
self.ntoreturn,
|
||||
self.max_await_time_ms,
|
||||
self.comment,
|
||||
sock_info,
|
||||
connection,
|
||||
)
|
||||
if self.session:
|
||||
self.session._apply_to(cmd, False, self.read_preference, sock_info)
|
||||
sock_info.add_server_api(cmd)
|
||||
sock_info.send_cluster_time(cmd, self.session, self.client)
|
||||
self.session._apply_to(cmd, False, self.read_preference, connection)
|
||||
connection.add_server_api(cmd)
|
||||
connection.send_cluster_time(cmd, self.session, self.client)
|
||||
# Support auto encryption
|
||||
client = self.client
|
||||
if client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
|
||||
# Support CSOT
|
||||
if apply_timeout:
|
||||
sock_info.apply_timeout(client, cmd=None)
|
||||
connection.apply_timeout(client, cmd=None)
|
||||
self._as_command = cmd, self.db
|
||||
return self._as_command
|
||||
|
||||
def get_message(self, dummy0, sock_info, use_cmd=False):
|
||||
def get_message(self, dummy0, connection, use_cmd=False):
|
||||
"""Get a getmore message."""
|
||||
ns = self.namespace()
|
||||
ctx = sock_info.compression_context
|
||||
ctx = connection.compression_context
|
||||
|
||||
if use_cmd:
|
||||
spec = self.as_command(sock_info, apply_timeout=True)[0]
|
||||
spec = self.as_command(connection, apply_timeout=True)[0]
|
||||
if self.sock_mgr:
|
||||
flags = _OpMsg.EXHAUST_ALLOWED
|
||||
else:
|
||||
flags = 0
|
||||
request_id, msg, size, _ = _op_msg(
|
||||
flags, spec, self.db, None, self.codec_options, ctx=sock_info.compression_context
|
||||
flags, spec, self.db, None, self.codec_options, ctx=connection.compression_context
|
||||
)
|
||||
return request_id, msg, size
|
||||
|
||||
@ -535,10 +535,10 @@ class _GetMore:
|
||||
|
||||
|
||||
class _RawBatchQuery(_Query):
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, connection):
|
||||
# Compatibility checks.
|
||||
super().use_command(sock_info)
|
||||
if sock_info.max_wire_version >= 8:
|
||||
super().use_command(connection)
|
||||
if connection.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
elif not self.exhaust:
|
||||
@ -547,10 +547,10 @@ class _RawBatchQuery(_Query):
|
||||
|
||||
|
||||
class _RawBatchGetMore(_GetMore):
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, connection):
|
||||
# Compatibility checks.
|
||||
super().use_command(sock_info)
|
||||
if sock_info.max_wire_version >= 8:
|
||||
super().use_command(connection)
|
||||
if connection.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
elif not self.exhaust:
|
||||
@ -794,11 +794,11 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
|
||||
|
||||
|
||||
class _BulkWriteContext:
|
||||
"""A wrapper around SocketInfo for use with write splitting functions."""
|
||||
"""A wrapper around Connection for use with write splitting functions."""
|
||||
|
||||
__slots__ = (
|
||||
"db_name",
|
||||
"sock_info",
|
||||
"connection",
|
||||
"op_id",
|
||||
"name",
|
||||
"field",
|
||||
@ -812,10 +812,10 @@ class _BulkWriteContext:
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec
|
||||
self, database_name, cmd_name, connection, operation_id, listeners, session, op_type, codec
|
||||
):
|
||||
self.db_name = database_name
|
||||
self.sock_info = sock_info
|
||||
self.connection = connection
|
||||
self.op_id = operation_id
|
||||
self.listeners = listeners
|
||||
self.publish = listeners.enabled_for_commands
|
||||
@ -823,7 +823,7 @@ class _BulkWriteContext:
|
||||
self.field = _FIELD_MAP[self.name]
|
||||
self.start_time = datetime.datetime.now() if self.publish else None
|
||||
self.session = session
|
||||
self.compress = True if sock_info.compression_context else False
|
||||
self.compress = True if connection.compression_context else False
|
||||
self.op_type = op_type
|
||||
self.codec = codec
|
||||
|
||||
@ -855,20 +855,20 @@ class _BulkWriteContext:
|
||||
@property
|
||||
def max_bson_size(self):
|
||||
"""A proxy for SockInfo.max_bson_size."""
|
||||
return self.sock_info.max_bson_size
|
||||
return self.connection.max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self):
|
||||
"""A proxy for SockInfo.max_message_size."""
|
||||
if self.compress:
|
||||
# Subtract 16 bytes for the message header.
|
||||
return self.sock_info.max_message_size - 16
|
||||
return self.sock_info.max_message_size
|
||||
return self.connection.max_message_size - 16
|
||||
return self.connection.max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self):
|
||||
"""A proxy for SockInfo.max_write_batch_size."""
|
||||
return self.sock_info.max_write_batch_size
|
||||
return self.connection.max_write_batch_size
|
||||
|
||||
@property
|
||||
def max_split_size(self):
|
||||
@ -876,14 +876,14 @@ class _BulkWriteContext:
|
||||
return self.max_bson_size
|
||||
|
||||
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
|
||||
"""A proxy for SocketInfo.unack_write that handles event publishing."""
|
||||
"""A proxy for Connection.unack_write that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
cmd = self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
result = self.sock_info.unack_write(msg, max_doc_size)
|
||||
result = self.connection.unack_write(msg, max_doc_size)
|
||||
if self.publish:
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
if result is not None:
|
||||
@ -910,14 +910,14 @@ class _BulkWriteContext:
|
||||
|
||||
@_handle_reauth
|
||||
def write_command(self, cmd, request_id, msg, docs):
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
"""A proxy for Connection.write_command that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
reply = self.sock_info.write_command(request_id, msg, self.codec)
|
||||
reply = self.connection.write_command(request_id, msg, self.codec)
|
||||
if self.publish:
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
self._succeed(request_id, reply, duration)
|
||||
@ -941,9 +941,9 @@ class _BulkWriteContext:
|
||||
cmd,
|
||||
self.db_name,
|
||||
request_id,
|
||||
self.sock_info.address,
|
||||
self.connection.address,
|
||||
self.op_id,
|
||||
self.sock_info.service_id,
|
||||
self.connection.service_id,
|
||||
)
|
||||
return cmd
|
||||
|
||||
@ -954,9 +954,9 @@ class _BulkWriteContext:
|
||||
reply,
|
||||
self.name,
|
||||
request_id,
|
||||
self.sock_info.address,
|
||||
self.connection.address,
|
||||
self.op_id,
|
||||
self.sock_info.service_id,
|
||||
self.connection.service_id,
|
||||
)
|
||||
|
||||
def _fail(self, request_id, failure, duration):
|
||||
@ -966,9 +966,9 @@ class _BulkWriteContext:
|
||||
failure,
|
||||
self.name,
|
||||
request_id,
|
||||
self.sock_info.address,
|
||||
self.connection.address,
|
||||
self.op_id,
|
||||
self.sock_info.service_id,
|
||||
self.connection.service_id,
|
||||
)
|
||||
|
||||
|
||||
@ -997,14 +997,14 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
|
||||
def execute(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
result = self.sock_info.command(
|
||||
result = self.connection.command(
|
||||
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
|
||||
)
|
||||
return result, to_send
|
||||
|
||||
def execute_unack(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
self.sock_info.command(
|
||||
self.connection.command(
|
||||
self.db_name,
|
||||
batched_cmd,
|
||||
write_concern=WriteConcern(w=0),
|
||||
@ -1124,7 +1124,7 @@ def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx):
|
||||
"""
|
||||
data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
|
||||
|
||||
request_id, msg = _compress(2013, data, ctx.sock_info.compression_context)
|
||||
request_id, msg = _compress(2013, data, ctx.connection.compression_context)
|
||||
return request_id, msg, to_send
|
||||
|
||||
|
||||
@ -1162,7 +1162,7 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
|
||||
ack = bool(command["writeConcern"].get("w", 1))
|
||||
else:
|
||||
ack = True
|
||||
if ctx.sock_info.compression_context:
|
||||
if ctx.connection.compression_context:
|
||||
return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx)
|
||||
return _batched_op_msg(operation, command, docs, ack, opts, ctx)
|
||||
|
||||
|
||||
@ -1160,18 +1160,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _end_sessions(self, session_ids):
|
||||
"""Send endSessions command(s) with the given session ids."""
|
||||
try:
|
||||
# Use SocketInfo.command directly to avoid implicitly creating
|
||||
# Use Connection.command directly to avoid implicitly creating
|
||||
# another session.
|
||||
with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as (
|
||||
sock_info,
|
||||
connection,
|
||||
read_pref,
|
||||
):
|
||||
if not sock_info.supports_sessions:
|
||||
if not connection.supports_sessions:
|
||||
return
|
||||
|
||||
for i in range(0, len(session_ids), common._MAX_END_SESSIONS):
|
||||
spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])])
|
||||
sock_info.command("admin", spec, read_preference=read_pref, client=self)
|
||||
connection.command("admin", spec, read_preference=read_pref, client=self)
|
||||
except PyMongoError:
|
||||
# Drivers MUST ignore any errors returned by the endSessions
|
||||
# command.
|
||||
@ -1224,23 +1224,23 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
err_handler.contribute_socket(session._pinned_connection)
|
||||
yield session._pinned_connection
|
||||
return
|
||||
with server.get_socket(handler=err_handler) as sock_info:
|
||||
with server.get_socket(handler=err_handler) as connection:
|
||||
# Pin this session to the selected server or connection.
|
||||
if in_txn and server.description.server_type in (
|
||||
SERVER_TYPE.Mongos,
|
||||
SERVER_TYPE.LoadBalancer,
|
||||
):
|
||||
session._pin(server, sock_info)
|
||||
err_handler.contribute_socket(sock_info)
|
||||
session._pin(server, connection)
|
||||
err_handler.contribute_socket(connection)
|
||||
if (
|
||||
self._encrypter
|
||||
and not self._encrypter._bypass_auto_encryption
|
||||
and sock_info.max_wire_version < 8
|
||||
and connection.max_wire_version < 8
|
||||
):
|
||||
raise ConfigurationError(
|
||||
"Auto-encryption requires a minimum MongoDB version of 4.2"
|
||||
)
|
||||
yield sock_info
|
||||
yield connection
|
||||
|
||||
def _select_server(self, server_selector, session, address=None):
|
||||
"""Select a server to run an operation on this client.
|
||||
@ -1281,7 +1281,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _socket_from_server(self, read_preference, server, session):
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
# Get a socket for a server matching the read preference, and yield
|
||||
# sock_info with the effective read preference. The Server Selection
|
||||
# connection with the effective read preference. The Server Selection
|
||||
# Spec says not to send any $readPreference to standalones and to
|
||||
# always send primaryPreferred when directly connected to a repl set
|
||||
# member.
|
||||
@ -1289,16 +1289,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
topology = self._get_topology()
|
||||
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
|
||||
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
with self._get_socket(server, session) as connection:
|
||||
if single:
|
||||
if sock_info.is_repl and not (session and session.in_transaction):
|
||||
if connection.is_repl and not (session and session.in_transaction):
|
||||
# Use primary preferred to ensure any repl set member
|
||||
# can handle the request.
|
||||
read_preference = ReadPreference.PRIMARY_PREFERRED
|
||||
elif sock_info.is_standalone:
|
||||
elif connection.is_standalone:
|
||||
# Don't send read preference to standalones.
|
||||
read_preference = ReadPreference.PRIMARY
|
||||
yield sock_info, read_preference
|
||||
yield connection, read_preference
|
||||
|
||||
def _socket_for_reads(self, read_preference, session):
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
@ -1331,10 +1331,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res
|
||||
)
|
||||
|
||||
def _cmd(session, server, sock_info, read_preference):
|
||||
def _cmd(session, server, connection, read_preference):
|
||||
operation.reset() # Reset op in case of retry.
|
||||
return server.run_operation(
|
||||
sock_info, operation, read_preference, self._event_listeners, unpack_res
|
||||
connection, operation, read_preference, self._event_listeners, unpack_res
|
||||
)
|
||||
|
||||
return self._retryable_read(
|
||||
@ -1388,8 +1388,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
supports_session = (
|
||||
session is not None and server.description.retryable_writes_supported
|
||||
)
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
max_wire_version = sock_info.max_wire_version
|
||||
with self._get_socket(server, session) as connection:
|
||||
max_wire_version = connection.max_wire_version
|
||||
if retryable and not supports_session:
|
||||
if is_retrying():
|
||||
# A retry is not possible because this server does
|
||||
@ -1397,7 +1397,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
retryable = False
|
||||
return func(session, sock_info, retryable)
|
||||
return func(session, connection, retryable)
|
||||
except ServerSelectionTimeoutError:
|
||||
if is_retrying():
|
||||
# The application may think the write was never attempted
|
||||
@ -1455,13 +1455,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
raise last_error
|
||||
try:
|
||||
server = self._select_server(read_pref, session, address=address)
|
||||
with self._socket_from_server(read_pref, server, session) as (sock_info, read_pref):
|
||||
with self._socket_from_server(read_pref, server, session) as (
|
||||
connection,
|
||||
read_pref,
|
||||
):
|
||||
if retrying and not retryable:
|
||||
# A retry is not possible because this server does
|
||||
# not support retryable reads, raise the last error.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
return func(session, server, sock_info, read_pref)
|
||||
return func(session, server, connection, read_pref)
|
||||
except ServerSelectionTimeoutError:
|
||||
if retrying:
|
||||
# The application may think the write was never attempted
|
||||
@ -1633,14 +1636,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Application called close_cursor() with no address.
|
||||
server = topology.select_server(writable_server_selector)
|
||||
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
self._kill_cursor_impl(cursor_ids, address, session, sock_info)
|
||||
with self._get_socket(server, session) as connection:
|
||||
self._kill_cursor_impl(cursor_ids, address, session, connection)
|
||||
|
||||
def _kill_cursor_impl(self, cursor_ids, address, session, sock_info):
|
||||
def _kill_cursor_impl(self, cursor_ids, address, session, connection):
|
||||
namespace = address.namespace
|
||||
db, coll = namespace.split(".", 1)
|
||||
spec = SON([("killCursors", coll), ("cursors", cursor_ids)])
|
||||
sock_info.command(db, spec, session=session, client=self)
|
||||
connection.command(db, spec, session=session, client=self)
|
||||
|
||||
def _process_kill_cursors(self):
|
||||
"""Process any pending kill cursors requests."""
|
||||
@ -1925,9 +1928,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_database must be an instance of str or a Database")
|
||||
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._socket_for_writes(session) as connection:
|
||||
self[name]._command(
|
||||
sock_info,
|
||||
connection,
|
||||
{"dropDatabase": 1, "comment": comment},
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=self._write_concern_for(session),
|
||||
@ -2149,11 +2152,11 @@ class _MongoClientErrorHandler:
|
||||
self.service_id = None
|
||||
self.handled = False
|
||||
|
||||
def contribute_socket(self, sock_info, completed_handshake=True):
|
||||
def contribute_socket(self, connection, completed_handshake=True):
|
||||
"""Provide socket information to the error handler."""
|
||||
self.max_wire_version = sock_info.max_wire_version
|
||||
self.sock_generation = sock_info.generation
|
||||
self.service_id = sock_info.service_id
|
||||
self.max_wire_version = connection.max_wire_version
|
||||
self.sock_generation = connection.generation
|
||||
self.service_id = connection.service_id
|
||||
self.completed_handshake = completed_handshake
|
||||
|
||||
def handle(self, exc_type, exc_val):
|
||||
|
||||
@ -245,9 +245,9 @@ class Monitor(MonitorBase):
|
||||
|
||||
if self._cancel_context and self._cancel_context.cancelled:
|
||||
self._reset_connection()
|
||||
with self._pool.get_socket() as sock_info:
|
||||
self._cancel_context = sock_info.cancel_context
|
||||
response, round_trip_time = self._check_with_socket(sock_info)
|
||||
with self._pool.get_socket() as connection:
|
||||
self._cancel_context = connection.cancel_context
|
||||
response, round_trip_time = self._check_with_socket(connection)
|
||||
if not response.awaitable:
|
||||
self._rtt_monitor.add_sample(round_trip_time)
|
||||
|
||||
@ -393,11 +393,11 @@ class _RttMonitor(MonitorBase):
|
||||
|
||||
def _ping(self):
|
||||
"""Run a "hello" command and return the RTT."""
|
||||
with self._pool.get_socket() as sock_info:
|
||||
with self._pool.get_socket() as connection:
|
||||
if self._executor._stopped:
|
||||
raise Exception("_RttMonitor closed")
|
||||
start = time.monotonic()
|
||||
sock_info.hello()
|
||||
connection.hello()
|
||||
return time.monotonic() - start
|
||||
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address
|
||||
@ -62,7 +62,7 @@ _UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
|
||||
|
||||
def command(
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
is_mongos: bool,
|
||||
@ -125,7 +125,7 @@ def command(
|
||||
if read_concern.level:
|
||||
spec["readConcern"] = read_concern.document
|
||||
if session:
|
||||
session._update_read_concern(spec, sock_info)
|
||||
session._update_read_concern(spec, connection)
|
||||
if collation is not None:
|
||||
spec["collation"] = collation
|
||||
|
||||
@ -142,7 +142,7 @@ def command(
|
||||
|
||||
# Support CSOT
|
||||
if client:
|
||||
sock_info.apply_timeout(client, spec)
|
||||
connection.apply_timeout(client, spec)
|
||||
_csot.apply_write_concern(spec, write_concern)
|
||||
|
||||
if use_op_msg:
|
||||
@ -167,19 +167,19 @@ def command(
|
||||
encoding_duration = datetime.datetime.now() - start
|
||||
assert listeners is not None
|
||||
listeners.publish_command_start(
|
||||
orig, dbname, request_id, address, service_id=sock_info.service_id
|
||||
orig, dbname, request_id, address, service_id=connection.service_id
|
||||
)
|
||||
start = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
sock_info.sock.sendall(msg)
|
||||
connection.send_message(msg, max_bson_size)
|
||||
if use_op_msg and unacknowledged:
|
||||
# Unacknowledged, fake a successful command response.
|
||||
reply = None
|
||||
response_doc = {"ok": 1}
|
||||
else:
|
||||
reply = receive_message(sock_info, request_id)
|
||||
sock_info.more_to_come = reply.more_to_come
|
||||
reply = connection.receive_message(request_id)
|
||||
connection.more_to_come = reply.more_to_come
|
||||
unpacked_docs = reply.unpack_response(
|
||||
codec_options=codec_options, user_fields=user_fields
|
||||
)
|
||||
@ -190,7 +190,7 @@ def command(
|
||||
if check:
|
||||
helpers._check_command_response(
|
||||
response_doc,
|
||||
sock_info.max_wire_version,
|
||||
connection.max_wire_version,
|
||||
allowable_errors,
|
||||
parse_write_concern_error=parse_write_concern_error,
|
||||
)
|
||||
@ -203,7 +203,7 @@ def command(
|
||||
failure = message._convert_exception(exc)
|
||||
assert listeners is not None
|
||||
listeners.publish_command_failure(
|
||||
duration, failure, name, request_id, address, service_id=sock_info.service_id
|
||||
duration, failure, name, request_id, address, service_id=connection.service_id
|
||||
)
|
||||
raise
|
||||
if publish:
|
||||
@ -215,7 +215,7 @@ def command(
|
||||
name,
|
||||
request_id,
|
||||
address,
|
||||
service_id=sock_info.service_id,
|
||||
service_id=connection.service_id,
|
||||
speculative_hello=speculative_hello,
|
||||
)
|
||||
|
||||
@ -229,21 +229,21 @@ def command(
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
|
||||
|
||||
def receive_message(
|
||||
sock_info: SocketInfo, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
|
||||
def receive_message_tcp(
|
||||
connection: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
timeout = sock_info.sock.gettimeout()
|
||||
timeout = connection.connector.gettimeout()
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(
|
||||
_receive_data_on_socket(sock_info, 16, deadline)
|
||||
_receive_data_on_socket(connection, 16, deadline)
|
||||
)
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
@ -260,11 +260,11 @@ def receive_message(
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
|
||||
_receive_data_on_socket(sock_info, 9, deadline)
|
||||
_receive_data_on_socket(connection, 9, deadline)
|
||||
)
|
||||
data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id)
|
||||
data = decompress(_receive_data_on_socket(connection, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = _receive_data_on_socket(sock_info, length - 16, deadline)
|
||||
data = _receive_data_on_socket(connection, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
@ -273,15 +273,51 @@ def receive_message(
|
||||
return unpack_reply(data)
|
||||
|
||||
|
||||
def receive_message_grpc(
|
||||
connection: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
# Ignore the response's request id.
|
||||
if connection.response is not None:
|
||||
data = connection.response.__next__()
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(data[:16])
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > max_message_size:
|
||||
raise ProtocolError(
|
||||
"Message length ({!r}) is larger than server max "
|
||||
"message size ({!r})".format(length, max_message_size)
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(data[:9])
|
||||
data = decompress(data[25:], compressor_id)
|
||||
else:
|
||||
data = data[16:]
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}")
|
||||
return unpack_reply(data)
|
||||
else:
|
||||
raise ProtocolError(f"Received empty response for request id {request_id}")
|
||||
|
||||
|
||||
_POLL_TIMEOUT = 0.5
|
||||
|
||||
|
||||
def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
|
||||
def wait_for_read(connection: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
context = sock_info.cancel_context
|
||||
context = connection.cancel_context
|
||||
# Only Monitor connections can be cancelled.
|
||||
if context:
|
||||
sock = sock_info.sock
|
||||
sock = connection.connector
|
||||
timed_out = False
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
@ -300,7 +336,7 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
readable = connection.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if context.cancelled:
|
||||
raise _OperationCancelled("hello cancelled")
|
||||
if readable:
|
||||
@ -314,20 +350,20 @@ BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
def _receive_data_on_socket(
|
||||
sock_info: SocketInfo, length: int, deadline: Optional[float]
|
||||
connection: Connection, length: int, deadline: Optional[float]
|
||||
) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
try:
|
||||
wait_for_read(sock_info, deadline)
|
||||
wait_for_read(connection, deadline)
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = sock_info.sock.recv_into(mv[bytes_read:])
|
||||
connection.set_connector_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = connection.connector.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out")
|
||||
except OSError as exc: # noqa: B014
|
||||
|
||||
267
pymongo/pool.py
267
pymongo/pool.py
@ -23,7 +23,14 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Dict, NoReturn, Optional
|
||||
from typing import Any, Dict, Iterator, NoReturn, Optional
|
||||
|
||||
try:
|
||||
import grpc # type: ignore
|
||||
|
||||
_HAVE_GRPC = True
|
||||
except ImportError:
|
||||
_HAVE_GRPC = False
|
||||
|
||||
import bson
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
@ -60,7 +67,7 @@ from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.helpers import _handle_reauth
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
from pymongo.network import command, receive_message
|
||||
from pymongo.network import command, receive_message_grpc, receive_message_tcp
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
@ -370,6 +377,12 @@ def _cond_wait(condition, deadline):
|
||||
return condition.wait(timeout)
|
||||
|
||||
|
||||
class ConnectionProtocol:
|
||||
TCP_SOCKET = 0
|
||||
|
||||
GRPC = 1
|
||||
|
||||
|
||||
class PoolOptions:
|
||||
"""Read only connection pool options for a MongoClient.
|
||||
|
||||
@ -402,6 +415,7 @@ class PoolOptions:
|
||||
"__server_api",
|
||||
"__load_balanced",
|
||||
"__credentials",
|
||||
"__protocol",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -423,6 +437,7 @@ class PoolOptions:
|
||||
server_api=None,
|
||||
load_balanced=None,
|
||||
credentials=None,
|
||||
protocol=ConnectionProtocol.TCP_SOCKET,
|
||||
):
|
||||
self.__max_pool_size = max_pool_size
|
||||
self.__min_pool_size = min_pool_size
|
||||
@ -441,6 +456,7 @@ class PoolOptions:
|
||||
self.__server_api = server_api
|
||||
self.__load_balanced = load_balanced
|
||||
self.__credentials = credentials
|
||||
self.__protocol = protocol
|
||||
self.__metadata = copy.deepcopy(_METADATA)
|
||||
if appname:
|
||||
self.__metadata["application"] = {"name": appname}
|
||||
@ -473,6 +489,11 @@ class PoolOptions:
|
||||
|
||||
_truncate_metadata(self.__metadata)
|
||||
|
||||
@property
|
||||
def protocol(self):
|
||||
"""A :class:`~pymongo.pool.ConnectionProtocol` value."""
|
||||
return self.__protocol
|
||||
|
||||
@property
|
||||
def _credentials(self):
|
||||
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
|
||||
@ -614,21 +635,22 @@ class _CancellationContext:
|
||||
return self._cancelled
|
||||
|
||||
|
||||
class SocketInfo:
|
||||
"""Store a socket with some metadata.
|
||||
class Connection:
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
:Parameters:
|
||||
- `sock`: a raw socket object
|
||||
- `connector`: a raw connector implementation, currently either a TCP Socket or a gRPC Stream
|
||||
- `pool`: a Pool instance
|
||||
- `address`: the server's (host, port)
|
||||
- `id`: the id of this socket in it's pool
|
||||
- `id`: the id of this socket in its pool
|
||||
"""
|
||||
|
||||
def __init__(self, sock, pool, address, id):
|
||||
def __init__(self, connector, pool, address, id, protocol):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.sock = sock
|
||||
self.connector = connector
|
||||
self.address = address
|
||||
self.id = id
|
||||
self.protocol = protocol
|
||||
self.authed = set()
|
||||
self.closed = False
|
||||
self.last_checkin_time = time.monotonic()
|
||||
@ -671,14 +693,18 @@ class SocketInfo:
|
||||
self.pinned_cursor = False
|
||||
self.active = False
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.current_timeout = self.last_timeout
|
||||
self.connect_rtt = 0.0
|
||||
self.response: Optional[Iterator] = None
|
||||
|
||||
def set_socket_timeout(self, timeout):
|
||||
"""Cache last timeout to avoid duplicate calls to sock.settimeout."""
|
||||
def set_connector_timeout(self, timeout):
|
||||
"""Cache last timeout to avoid duplicate calls to connector timeout implementations."""
|
||||
if timeout == self.last_timeout:
|
||||
return
|
||||
self.last_timeout = timeout
|
||||
self.sock.settimeout(timeout)
|
||||
self.current_timeout = timeout
|
||||
if self.protocol == ConnectionProtocol.TCP_SOCKET: # TODO: Implement timeouts for gRPC mode
|
||||
self.connector.settimeout(timeout)
|
||||
|
||||
def apply_timeout(self, client, cmd):
|
||||
# CSOT: use remaining timeout when set.
|
||||
@ -686,7 +712,7 @@ class SocketInfo:
|
||||
if timeout is None:
|
||||
# Reset the socket timeout unless we're performing a streaming monitor check.
|
||||
if not self.more_to_come:
|
||||
self.set_socket_timeout(self.opts.socket_timeout)
|
||||
self.set_connector_timeout(self.opts.socket_timeout)
|
||||
return None
|
||||
# RTT validation.
|
||||
rtt = _csot.get_rtt()
|
||||
@ -701,7 +727,7 @@ class SocketInfo:
|
||||
)
|
||||
if cmd is not None:
|
||||
cmd["maxTimeMS"] = int(max_time_ms * 1000)
|
||||
self.set_socket_timeout(timeout)
|
||||
self.set_connector_timeout(timeout)
|
||||
return timeout
|
||||
|
||||
def pin_txn(self):
|
||||
@ -748,7 +774,7 @@ class SocketInfo:
|
||||
awaitable = True
|
||||
# If connect_timeout is None there is no timeout.
|
||||
if self.opts.connect_timeout:
|
||||
self.set_socket_timeout(self.opts.connect_timeout + heartbeat_frequency)
|
||||
self.set_connector_timeout(self.opts.connect_timeout + heartbeat_frequency)
|
||||
|
||||
if not performing_handshake and cluster_time is not None:
|
||||
cmd["$clusterTime"] = cluster_time
|
||||
@ -910,7 +936,7 @@ class SocketInfo:
|
||||
def send_message(self, message, max_doc_size):
|
||||
"""Send a raw BSON message or raise ConnectionFailure.
|
||||
|
||||
If a network exception is raised, the socket is closed.
|
||||
If a network exception is raised, the connector is closed.
|
||||
"""
|
||||
if self.max_bson_size is not None and max_doc_size > self.max_bson_size:
|
||||
raise DocumentTooLarge(
|
||||
@ -919,17 +945,32 @@ class SocketInfo:
|
||||
)
|
||||
|
||||
try:
|
||||
self.sock.sendall(message)
|
||||
if self.protocol == ConnectionProtocol.GRPC:
|
||||
self.response = self.connector.__call__(
|
||||
iter([message]),
|
||||
metadata=[
|
||||
("security-uuid", "uuid"),
|
||||
("username", "user"),
|
||||
("servername", "host.local.10gen.cc"),
|
||||
("mongodb-wireversion", "18"),
|
||||
("x-forwarded-for", "127.0.0.1:9901"),
|
||||
],
|
||||
)
|
||||
else:
|
||||
self.connector.sendall(message)
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
def receive_message(self, request_id):
|
||||
"""Receive a raw BSON message or raise ConnectionFailure.
|
||||
|
||||
If any exception is raised, the socket is closed.
|
||||
If any exception is raised, the connector is closed.
|
||||
"""
|
||||
try:
|
||||
return receive_message(self, request_id, self.max_message_size)
|
||||
if self.protocol == ConnectionProtocol.GRPC:
|
||||
return receive_message_grpc(self, request_id, self.max_message_size)
|
||||
else:
|
||||
return receive_message_tcp(self, request_id, self.max_message_size)
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -1017,13 +1058,13 @@ class SocketInfo:
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
self.sock.close()
|
||||
self.connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def socket_closed(self):
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
return self.socket_checker.socket_closed(self.sock)
|
||||
return self.socket_checker.socket_closed(self.connector)
|
||||
|
||||
def send_cluster_time(self, command, session, client):
|
||||
"""Add $clusterTime."""
|
||||
@ -1073,17 +1114,17 @@ class SocketInfo:
|
||||
raise
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.sock == other.sock
|
||||
return self.connector == other.connector
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.sock)
|
||||
return hash(self.connector)
|
||||
|
||||
def __repr__(self):
|
||||
return "SocketInfo({}){} at {}".format(
|
||||
repr(self.sock),
|
||||
return "Connection({}){} at {}".format(
|
||||
repr(self.connector),
|
||||
self.closed and " CLOSED" or "",
|
||||
id(self),
|
||||
)
|
||||
@ -1256,7 +1297,7 @@ class Pool:
|
||||
:Parameters:
|
||||
- `address`: a (hostname, port) tuple
|
||||
- `options`: a PoolOptions instance
|
||||
- `handshake`: whether to call hello for each new SocketInfo
|
||||
- `handshake`: whether to call hello for each new Connection
|
||||
"""
|
||||
if options.pause_enabled:
|
||||
self.state = PoolState.PAUSED
|
||||
@ -1318,6 +1359,11 @@ class Pool:
|
||||
self.ncursors = 0
|
||||
self.ntxns = 0
|
||||
|
||||
if self.opts.protocol == ConnectionProtocol.GRPC:
|
||||
self.channel = self._create_grpc_channel()
|
||||
else:
|
||||
self.channel = None
|
||||
|
||||
def ready(self):
|
||||
# Take the lock to avoid the race condition described in PYTHON-2699.
|
||||
with self.lock:
|
||||
@ -1348,11 +1394,11 @@ class Pool:
|
||||
else:
|
||||
discard: collections.deque = collections.deque()
|
||||
keep: collections.deque = collections.deque()
|
||||
for sock_info in self.sockets:
|
||||
if sock_info.service_id == service_id:
|
||||
discard.append(sock_info)
|
||||
for connection in self.sockets:
|
||||
if connection.service_id == service_id:
|
||||
discard.append(connection)
|
||||
else:
|
||||
keep.append(sock_info)
|
||||
keep.append(connection)
|
||||
sockets = discard
|
||||
self.sockets = keep
|
||||
|
||||
@ -1367,15 +1413,15 @@ class Pool:
|
||||
# PoolClosedEvent but that reset() SHOULD close sockets *after*
|
||||
# publishing the PoolClearedEvent.
|
||||
if close:
|
||||
for sock_info in sockets:
|
||||
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
for connection in sockets:
|
||||
connection.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_pool_closed(self.address)
|
||||
else:
|
||||
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
|
||||
listeners.publish_pool_cleared(self.address, service_id=service_id)
|
||||
for sock_info in sockets:
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
for connection in sockets:
|
||||
connection.close_socket(ConnectionClosedReason.STALE)
|
||||
|
||||
def update_is_writable(self, is_writable):
|
||||
"""Updates the is_writable attribute on all sockets currently in the
|
||||
@ -1395,6 +1441,10 @@ class Pool:
|
||||
def close(self):
|
||||
self._reset(close=True)
|
||||
|
||||
def _create_grpc_channel(self):
|
||||
connection_string = self.address[0] + ":" + str(self.address[1])
|
||||
return grpc.insecure_channel(connection_string, options=self.grpc_channel_options())
|
||||
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
@ -1415,8 +1465,8 @@ class Pool:
|
||||
self.sockets
|
||||
and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
sock_info = self.sockets.pop()
|
||||
sock_info.close_socket(ConnectionClosedReason.IDLE)
|
||||
connection = self.sockets.pop()
|
||||
connection.close_socket(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
with self.size_cond:
|
||||
@ -1435,14 +1485,14 @@ class Pool:
|
||||
return
|
||||
self._pending += 1
|
||||
incremented = True
|
||||
sock_info = self.connect()
|
||||
connection = self.connect()
|
||||
with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
if self.gen.get_overall() != reference_generation:
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
connection.close_socket(ConnectionClosedReason.STALE)
|
||||
return
|
||||
self.sockets.appendleft(sock_info)
|
||||
self.sockets.appendleft(connection)
|
||||
finally:
|
||||
if incremented:
|
||||
# Notify after adding the socket to the pool.
|
||||
@ -1455,7 +1505,7 @@ class Pool:
|
||||
self.size_cond.notify()
|
||||
|
||||
def connect(self, handler=None):
|
||||
"""Connect to Mongo and return a new SocketInfo.
|
||||
"""Connect to Mongo and return a new Connection.
|
||||
|
||||
Can raise ConnectionFailure.
|
||||
|
||||
@ -1471,7 +1521,12 @@ class Pool:
|
||||
listeners.publish_connection_created(self.address, conn_id)
|
||||
|
||||
try:
|
||||
sock = _configured_socket(self.address, self.opts)
|
||||
if self.opts.protocol == ConnectionProtocol.GRPC:
|
||||
connector = self.channel.stream_stream(
|
||||
"/mongodb.CommandService/UnauthenticatedCommandStream"
|
||||
)
|
||||
else:
|
||||
connector = _configured_socket(self.address, self.opts)
|
||||
except BaseException as error:
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_closed(
|
||||
@ -1483,26 +1538,42 @@ class Pool:
|
||||
|
||||
raise
|
||||
|
||||
sock_info = SocketInfo(sock, self, self.address, conn_id)
|
||||
connection = Connection(connector, self, self.address, conn_id, self.opts.protocol)
|
||||
try:
|
||||
if self.handshake:
|
||||
sock_info.hello()
|
||||
self.is_writable = sock_info.is_writable
|
||||
connection.hello()
|
||||
self.is_writable = connection.is_writable
|
||||
if handler:
|
||||
handler.contribute_socket(sock_info, completed_handshake=False)
|
||||
handler.contribute_socket(connection, completed_handshake=False)
|
||||
|
||||
sock_info.authenticate()
|
||||
connection.authenticate()
|
||||
except BaseException:
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
connection.close_socket(ConnectionClosedReason.ERROR)
|
||||
raise
|
||||
|
||||
return sock_info
|
||||
return connection
|
||||
|
||||
def grpc_metadata(self):
|
||||
return [
|
||||
("security-uuid", "client-uuid"),
|
||||
("username", "user"),
|
||||
("servername", "host.local.10gen.cc"),
|
||||
("mongodb-wireversion", 18),
|
||||
("x-forwarded-for", "127.0.0.1:9901"),
|
||||
]
|
||||
|
||||
def grpc_channel_options(self):
|
||||
return [
|
||||
("grpc.default_authority", "host.local.10gen.cc"),
|
||||
("grpc.max_receive_message_length", 48000000),
|
||||
("grpc.max_send_message_length", 48000000),
|
||||
]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, handler=None):
|
||||
"""Get a socket from the pool. Use with a "with" statement.
|
||||
|
||||
Returns a :class:`SocketInfo` object wrapping a connected
|
||||
Returns a :class:`Connection` object wrapping a connected
|
||||
:class:`socket.socket`.
|
||||
|
||||
This method should always be used in a with-statement::
|
||||
@ -1520,36 +1591,36 @@ class Pool:
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_check_out_started(self.address)
|
||||
|
||||
sock_info = self._get_socket(handler=handler)
|
||||
connection = self._get_socket(handler=handler)
|
||||
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_out(self.address, sock_info.id)
|
||||
listeners.publish_connection_checked_out(self.address, connection.id)
|
||||
try:
|
||||
yield sock_info
|
||||
yield connection
|
||||
except BaseException:
|
||||
# Exception in caller. Ensure the connection gets returned.
|
||||
# Note that when pinned is True, the session owns the
|
||||
# connection and it is responsible for checking the connection
|
||||
# back into the pool.
|
||||
pinned = sock_info.pinned_txn or sock_info.pinned_cursor
|
||||
pinned = connection.pinned_txn or connection.pinned_cursor
|
||||
if handler:
|
||||
# Perform SDAM error handling rules while the connection is
|
||||
# still checked out.
|
||||
exc_type, exc_val, _ = sys.exc_info()
|
||||
handler.handle(exc_type, exc_val)
|
||||
if not pinned and sock_info.active:
|
||||
self.return_socket(sock_info)
|
||||
if not pinned and connection.active:
|
||||
self.return_socket(connection)
|
||||
raise
|
||||
if sock_info.pinned_txn:
|
||||
if connection.pinned_txn:
|
||||
with self.lock:
|
||||
self.__pinned_sockets.add(sock_info)
|
||||
self.__pinned_sockets.add(connection)
|
||||
self.ntxns += 1
|
||||
elif sock_info.pinned_cursor:
|
||||
elif connection.pinned_cursor:
|
||||
with self.lock:
|
||||
self.__pinned_sockets.add(sock_info)
|
||||
self.__pinned_sockets.add(connection)
|
||||
self.ncursors += 1
|
||||
elif sock_info.active:
|
||||
self.return_socket(sock_info)
|
||||
elif connection.active:
|
||||
self.return_socket(connection)
|
||||
|
||||
def _raise_if_not_ready(self, emit_event):
|
||||
if self.state != PoolState.READY:
|
||||
@ -1560,7 +1631,7 @@ class Pool:
|
||||
_raise_connection_failure(self.address, AutoReconnect("connection pool paused"))
|
||||
|
||||
def _get_socket(self, handler=None):
|
||||
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
|
||||
"""Get or create a Connection. Can raise ConnectionFailure."""
|
||||
# We use the pid here to avoid issues with fork / multiprocessing.
|
||||
# See test.test_client:TestClient.test_fork for an example of
|
||||
# what could go wrong otherwise
|
||||
@ -1600,7 +1671,7 @@ class Pool:
|
||||
self.requests += 1
|
||||
|
||||
# We've now acquired the semaphore and must release it on error.
|
||||
sock_info = None
|
||||
connection = None
|
||||
incremented = False
|
||||
emitted_event = False
|
||||
try:
|
||||
@ -1608,7 +1679,7 @@ class Pool:
|
||||
self.active_sockets += 1
|
||||
incremented = True
|
||||
|
||||
while sock_info is None:
|
||||
while connection is None:
|
||||
# CMAP: we MUST wait for either maxConnecting OR for a socket
|
||||
# to be checked back into the pool.
|
||||
with self._max_connecting_cond:
|
||||
@ -1624,24 +1695,24 @@ class Pool:
|
||||
self._raise_if_not_ready(emit_event=False)
|
||||
|
||||
try:
|
||||
sock_info = self.sockets.popleft()
|
||||
connection = self.sockets.popleft()
|
||||
except IndexError:
|
||||
self._pending += 1
|
||||
if sock_info: # We got a socket from the pool
|
||||
if self._perished(sock_info):
|
||||
sock_info = None
|
||||
if connection: # We got a socket from the pool
|
||||
if self._perished(connection):
|
||||
connection = None
|
||||
continue
|
||||
else: # We need to create a new connection
|
||||
try:
|
||||
sock_info = self.connect(handler=handler)
|
||||
connection = self.connect(handler=handler)
|
||||
finally:
|
||||
with self._max_connecting_cond:
|
||||
self._pending -= 1
|
||||
self._max_connecting_cond.notify()
|
||||
except BaseException:
|
||||
if sock_info:
|
||||
if connection:
|
||||
# We checked out a socket but authentication failed.
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
connection.close_socket(ConnectionClosedReason.ERROR)
|
||||
with self.size_cond:
|
||||
self.requests -= 1
|
||||
if incremented:
|
||||
@ -1654,45 +1725,45 @@ class Pool:
|
||||
)
|
||||
raise
|
||||
|
||||
sock_info.active = True
|
||||
return sock_info
|
||||
connection.active = True
|
||||
return connection
|
||||
|
||||
def return_socket(self, sock_info):
|
||||
def return_socket(self, connection):
|
||||
"""Return the socket to the pool, or if it's closed discard it.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info`: The socket to check into the pool.
|
||||
- `connection`: The socket to check into the pool.
|
||||
"""
|
||||
txn = sock_info.pinned_txn
|
||||
cursor = sock_info.pinned_cursor
|
||||
sock_info.active = False
|
||||
sock_info.pinned_txn = False
|
||||
sock_info.pinned_cursor = False
|
||||
self.__pinned_sockets.discard(sock_info)
|
||||
txn = connection.pinned_txn
|
||||
cursor = connection.pinned_cursor
|
||||
connection.active = False
|
||||
connection.pinned_txn = False
|
||||
connection.pinned_cursor = False
|
||||
self.__pinned_sockets.discard(connection)
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_in(self.address, sock_info.id)
|
||||
listeners.publish_connection_checked_in(self.address, connection.id)
|
||||
if self.pid != os.getpid():
|
||||
self.reset_without_pause()
|
||||
else:
|
||||
if self.closed:
|
||||
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
elif sock_info.closed:
|
||||
connection.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
elif connection.closed:
|
||||
# CMAP requires the closed event be emitted after the check in.
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_closed(
|
||||
self.address, sock_info.id, ConnectionClosedReason.ERROR
|
||||
self.address, connection.id, ConnectionClosedReason.ERROR
|
||||
)
|
||||
else:
|
||||
with self.lock:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
if self.stale_generation(sock_info.generation, sock_info.service_id):
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
if self.stale_generation(connection.generation, connection.service_id):
|
||||
connection.close_socket(ConnectionClosedReason.STALE)
|
||||
else:
|
||||
sock_info.update_last_checkin_time()
|
||||
sock_info.update_is_writable(self.is_writable)
|
||||
self.sockets.appendleft(sock_info)
|
||||
connection.update_last_checkin_time()
|
||||
connection.update_is_writable(self.is_writable)
|
||||
self.sockets.appendleft(connection)
|
||||
# Notify any threads waiting to create a connection.
|
||||
self._max_connecting_cond.notify()
|
||||
|
||||
@ -1706,7 +1777,7 @@ class Pool:
|
||||
self.operation_count -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def _perished(self, sock_info):
|
||||
def _perished(self, connection):
|
||||
"""Return True and close the connection if it is "perished".
|
||||
|
||||
This side-effecty function checks if this socket has been idle for
|
||||
@ -1720,24 +1791,24 @@ class Pool:
|
||||
pool, to keep performance reasonable - we can't avoid AutoReconnects
|
||||
completely anyway.
|
||||
"""
|
||||
idle_time_seconds = sock_info.idle_time_seconds()
|
||||
idle_time_seconds = connection.idle_time_seconds()
|
||||
# If socket is idle, open a new one.
|
||||
if (
|
||||
self.opts.max_idle_time_seconds is not None
|
||||
and idle_time_seconds > self.opts.max_idle_time_seconds
|
||||
):
|
||||
sock_info.close_socket(ConnectionClosedReason.IDLE)
|
||||
connection.close_socket(ConnectionClosedReason.IDLE)
|
||||
return True
|
||||
|
||||
if self._check_interval_seconds is not None and (
|
||||
0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds
|
||||
):
|
||||
if sock_info.socket_closed():
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
if connection.socket_closed():
|
||||
connection.close_socket(ConnectionClosedReason.ERROR)
|
||||
return True
|
||||
|
||||
if self.stale_generation(sock_info.generation, sock_info.service_id):
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
if self.stale_generation(connection.generation, connection.service_id):
|
||||
connection.close_socket(ConnectionClosedReason.STALE)
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -1772,5 +1843,5 @@ class Pool:
|
||||
# Avoid ResourceWarnings in Python 3
|
||||
# Close all sockets without calling reset() or close() because it is
|
||||
# not safe to acquire a lock in __del__.
|
||||
for sock_info in self.sockets:
|
||||
sock_info.close_socket(None)
|
||||
for connection in self.sockets:
|
||||
connection.close_socket(None)
|
||||
|
||||
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.typings import _Address
|
||||
|
||||
|
||||
@ -91,7 +91,7 @@ class PinnedResponse(Response):
|
||||
self,
|
||||
data: Union[_OpMsg, _OpReply],
|
||||
address: _Address,
|
||||
socket_info: SocketInfo,
|
||||
socket_info: Connection,
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
@ -103,7 +103,7 @@ class PinnedResponse(Response):
|
||||
:Parameters:
|
||||
- `data`: A network response message.
|
||||
- `address`: (host, port) of the source server.
|
||||
- `socket_info`: The SocketInfo used for the initial query.
|
||||
- `socket_info`: The Connection used for the initial query.
|
||||
- `request_id`: The request id of this operation.
|
||||
- `duration`: The duration of the operation.
|
||||
- `from_command`: If the response is the result of a db command.
|
||||
@ -116,8 +116,8 @@ class PinnedResponse(Response):
|
||||
self._more_to_come = more_to_come
|
||||
|
||||
@property
|
||||
def socket_info(self) -> SocketInfo:
|
||||
"""The SocketInfo used for the initial query.
|
||||
def socket_info(self) -> Connection:
|
||||
"""The Connection used for the initial query.
|
||||
|
||||
The server will send batches on this socket, without waiting for
|
||||
getMores from the client, until the result set is exhausted or there
|
||||
|
||||
@ -42,7 +42,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.mongo_client import _MongoClientErrorHandler
|
||||
from pymongo.monitor import Monitor
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import Pool, SocketInfo
|
||||
from pymongo.pool import Connection, Pool
|
||||
from pymongo.server_description import ServerDescription
|
||||
|
||||
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
|
||||
@ -105,7 +105,7 @@ class Server:
|
||||
@_handle_reauth
|
||||
def run_operation(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
connection: Connection,
|
||||
operation: Union[_Query, _GetMore],
|
||||
read_preference: bool,
|
||||
listeners: _EventListeners,
|
||||
@ -118,7 +118,7 @@ class Server:
|
||||
Can raise ConnectionFailure, OperationFailure, etc.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info`: A SocketInfo instance.
|
||||
- `connection`: A Connection instance.
|
||||
- `operation`: A _Query or _GetMore object.
|
||||
- `read_preference`: The read preference to use.
|
||||
- `listeners`: Instance of _EventListeners or None.
|
||||
@ -129,27 +129,27 @@ class Server:
|
||||
if publish:
|
||||
start = datetime.now()
|
||||
|
||||
use_cmd = operation.use_command(sock_info)
|
||||
use_cmd = operation.use_command(connection)
|
||||
more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come
|
||||
if more_to_come:
|
||||
request_id = 0
|
||||
else:
|
||||
message = operation.get_message(read_preference, sock_info, use_cmd)
|
||||
message = operation.get_message(read_preference, connection, use_cmd)
|
||||
request_id, data, max_doc_size = self._split_message(message)
|
||||
|
||||
if publish:
|
||||
cmd, dbn = operation.as_command(sock_info)
|
||||
cmd, dbn = operation.as_command(connection)
|
||||
listeners.publish_command_start(
|
||||
cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id
|
||||
cmd, dbn, request_id, connection.address, service_id=connection.service_id
|
||||
)
|
||||
start = datetime.now()
|
||||
|
||||
try:
|
||||
if more_to_come:
|
||||
reply = sock_info.receive_message(None)
|
||||
reply = connection.receive_message(None)
|
||||
else:
|
||||
sock_info.send_message(data, max_doc_size)
|
||||
reply = sock_info.receive_message(request_id)
|
||||
connection.send_message(data, max_doc_size)
|
||||
reply = connection.receive_message(request_id)
|
||||
|
||||
# Unpack and check for command errors.
|
||||
if use_cmd:
|
||||
@ -168,7 +168,7 @@ class Server:
|
||||
if use_cmd:
|
||||
first = docs[0]
|
||||
operation.client._process_response(first, operation.session)
|
||||
_check_command_response(first, sock_info.max_wire_version)
|
||||
_check_command_response(first, connection.max_wire_version)
|
||||
except Exception as exc:
|
||||
if publish:
|
||||
duration = datetime.now() - start
|
||||
@ -181,8 +181,8 @@ class Server:
|
||||
failure,
|
||||
operation.name,
|
||||
request_id,
|
||||
sock_info.address,
|
||||
service_id=sock_info.service_id,
|
||||
connection.address,
|
||||
service_id=connection.service_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@ -205,8 +205,8 @@ class Server:
|
||||
res,
|
||||
operation.name,
|
||||
request_id,
|
||||
sock_info.address,
|
||||
service_id=sock_info.service_id,
|
||||
connection.address,
|
||||
service_id=connection.service_id,
|
||||
)
|
||||
|
||||
# Decrypt response.
|
||||
@ -219,7 +219,7 @@ class Server:
|
||||
response: Response
|
||||
|
||||
if client._should_pin_cursor(operation.session) or operation.exhaust:
|
||||
sock_info.pin_cursor()
|
||||
connection.pin_cursor()
|
||||
if isinstance(reply, _OpMsg):
|
||||
# In OP_MSG, the server keeps sending only if the
|
||||
# more_to_come flag is set.
|
||||
@ -232,7 +232,7 @@ class Server:
|
||||
response = PinnedResponse(
|
||||
data=reply,
|
||||
address=self._description.address,
|
||||
socket_info=sock_info,
|
||||
socket_info=connection,
|
||||
duration=duration,
|
||||
request_id=request_id,
|
||||
from_command=use_cmd,
|
||||
@ -253,7 +253,7 @@ class Server:
|
||||
|
||||
def get_socket(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
) -> ContextManager[SocketInfo]:
|
||||
) -> ContextManager[Connection]:
|
||||
return self.pool.get_socket(handler)
|
||||
|
||||
@property
|
||||
|
||||
@ -778,6 +778,7 @@ class Topology:
|
||||
driver=options.driver,
|
||||
pause_enabled=False,
|
||||
server_api=options.server_api,
|
||||
protocol=options.protocol,
|
||||
)
|
||||
|
||||
return self._settings.pool_class(address, monitor_pool_options, handshake=False)
|
||||
|
||||
@ -41,6 +41,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"dnspython>=1.16.0,<3.0.0",
|
||||
"grpcio>=1.56.0"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
7
test/grpc_example.py
Normal file
7
test/grpc_example.py
Normal file
@ -0,0 +1,7 @@
|
||||
from pymongo import MongoClient
|
||||
|
||||
client = MongoClient("mongodb://host9.local.10gen.cc:9901", grpc=True, loadBalanced=True)
|
||||
|
||||
dbs = client.list_databases()
|
||||
for db in dbs:
|
||||
print(db["name"], ": ", client[db["name"]].list_collection_names())
|
||||
@ -48,10 +48,10 @@ class MockPool(Pool):
|
||||
client.mock_standalones + client.mock_members + client.mock_mongoses
|
||||
), ("bad host: %s" % host_and_port)
|
||||
|
||||
with Pool.get_socket(self, handler) as sock_info:
|
||||
sock_info.mock_host = self.mock_host
|
||||
sock_info.mock_port = self.mock_port
|
||||
yield sock_info
|
||||
with Pool.get_socket(self, handler) as connection:
|
||||
connection.mock_host = self.mock_host
|
||||
connection.mock_port = self.mock_port
|
||||
yield connection
|
||||
|
||||
|
||||
class DummyMonitor:
|
||||
|
||||
@ -96,7 +96,7 @@ from pymongo.errors import (
|
||||
)
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
|
||||
from pymongo.pool import _METADATA, PoolOptions, SocketInfo
|
||||
from pymongo.pool import _METADATA, Connection, PoolOptions
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import readable_server_selector, writable_server_selector
|
||||
@ -541,10 +541,10 @@ class TestClient(IntegrationTest):
|
||||
# Assert reaper doesn't remove sockets when maxIdleTimeMS not set
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.get_socket() as connection:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertTrue(sock_info in server._pool.sockets)
|
||||
self.assertTrue(connection in server._pool.sockets)
|
||||
client.close()
|
||||
|
||||
def test_max_idle_time_reaper_removes_stale_minPoolSize(self):
|
||||
@ -552,12 +552,12 @@ class TestClient(IntegrationTest):
|
||||
# Assert reaper removes idle socket and replaces it with a new one
|
||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.get_socket() as connection:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket, two
|
||||
# sockets could be created and checked into the pool.
|
||||
self.assertGreaterEqual(len(server._pool.sockets), 1)
|
||||
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket")
|
||||
wait_until(lambda: connection not in server._pool.sockets, "remove stale socket")
|
||||
wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket")
|
||||
client.close()
|
||||
|
||||
@ -566,12 +566,12 @@ class TestClient(IntegrationTest):
|
||||
# Assert reaper respects maxPoolSize when adding new sockets.
|
||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.get_socket() as connection:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket,
|
||||
# maxPoolSize=1 should prevent two sockets from being created.
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket")
|
||||
wait_until(lambda: connection not in server._pool.sockets, "remove stale socket")
|
||||
wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket")
|
||||
client.close()
|
||||
|
||||
@ -605,39 +605,39 @@ class TestClient(IntegrationTest):
|
||||
wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets")
|
||||
|
||||
# Assert that if a socket is closed, a new one takes its place
|
||||
with server._pool.get_socket() as sock_info:
|
||||
sock_info.close_socket(None)
|
||||
with server._pool.get_socket() as connection:
|
||||
connection.close_socket(None)
|
||||
wait_until(
|
||||
lambda: 10 == len(server._pool.sockets),
|
||||
"a closed socket gets replaced from the pool",
|
||||
)
|
||||
self.assertFalse(sock_info in server._pool.sockets)
|
||||
self.assertFalse(connection in server._pool.sockets)
|
||||
|
||||
def test_max_idle_time_checkout(self):
|
||||
# Use high frequency to test _get_socket_no_auth.
|
||||
with client_knobs(kill_cursor_frequency=99999999):
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.get_socket() as connection:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
time.sleep(1) # Sleep so that the socket becomes stale.
|
||||
|
||||
with server._pool.get_socket() as new_sock_info:
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertNotEqual(connection, new_sock_info)
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertFalse(sock_info in server._pool.sockets)
|
||||
self.assertFalse(connection in server._pool.sockets)
|
||||
self.assertTrue(new_sock_info in server._pool.sockets)
|
||||
|
||||
# Test that sockets are reused if maxIdleTimeMS is not set.
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.get_socket() as connection:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
time.sleep(1)
|
||||
with server._pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(connection, new_sock_info)
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
|
||||
def test_constants(self):
|
||||
@ -1130,8 +1130,8 @@ class TestClient(IntegrationTest):
|
||||
|
||||
def test_socketKeepAlive(self):
|
||||
pool = get_pool(self.client)
|
||||
with pool.get_socket() as sock_info:
|
||||
keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
with pool.get_socket() as connection:
|
||||
keepalive = connection.connector.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
self.assertTrue(keepalive)
|
||||
|
||||
@no_type_check
|
||||
@ -1326,13 +1326,13 @@ class TestClient(IntegrationTest):
|
||||
connected(client)
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = one(pool.sockets)
|
||||
sock_info.sock.close()
|
||||
connection = one(pool.sockets)
|
||||
connection.connector.close()
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
next(cursor)
|
||||
|
||||
self.assertTrue(sock_info.closed)
|
||||
self.assertTrue(connection.closed)
|
||||
|
||||
# The semaphore was decremented despite the error.
|
||||
self.assertEqual(0, pool.requests)
|
||||
@ -1350,7 +1350,7 @@ class TestClient(IntegrationTest):
|
||||
socket_info = one(pool.sockets)
|
||||
socket_info.sock.close()
|
||||
|
||||
# SocketInfo.authenticate logs, but gets a socket.error. Should be
|
||||
# Connection.authenticate logs, but gets a socket.error. Should be
|
||||
# reraised as AutoReconnect.
|
||||
self.assertRaises(AutoReconnect, c.test.collection.find_one)
|
||||
|
||||
@ -1847,7 +1847,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
|
||||
collection = client.pymongo_test.test
|
||||
pool = get_pool(client)
|
||||
sock_info = one(pool.sockets)
|
||||
connection = one(pool.sockets)
|
||||
|
||||
# This will cause OperationFailure in all mongo versions since
|
||||
# the value for $orderby must be a document.
|
||||
@ -1856,10 +1856,10 @@ class TestExhaustCursor(IntegrationTest):
|
||||
)
|
||||
|
||||
self.assertRaises(OperationFailure, cursor.next)
|
||||
self.assertFalse(sock_info.closed)
|
||||
self.assertFalse(connection.closed)
|
||||
|
||||
# The socket was checked in and the semaphore was decremented.
|
||||
self.assertIn(sock_info, pool.sockets)
|
||||
self.assertIn(connection, pool.sockets)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
def test_exhaust_getmore_server_error(self):
|
||||
@ -1874,7 +1874,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
|
||||
pool = get_pool(client)
|
||||
pool._check_interval_seconds = None # Never check.
|
||||
sock_info = one(pool.sockets)
|
||||
connection = one(pool.sockets)
|
||||
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
|
||||
@ -1884,21 +1884,21 @@ class TestExhaustCursor(IntegrationTest):
|
||||
# Cause a server error on getmore.
|
||||
def receive_message(request_id):
|
||||
# Discard the actual server response.
|
||||
SocketInfo.receive_message(sock_info, request_id)
|
||||
Connection.receive_message(connection, request_id)
|
||||
|
||||
# responseFlags bit 1 is QueryFailure.
|
||||
msg = struct.pack("<iiiii", 1 << 1, 0, 0, 0, 0)
|
||||
msg += encode({"$err": "mock err", "code": 0})
|
||||
return message._OpReply.unpack(msg)
|
||||
|
||||
sock_info.receive_message = receive_message
|
||||
connection.receive_message = receive_message
|
||||
self.assertRaises(OperationFailure, list, cursor)
|
||||
# Unpatch the instance.
|
||||
del sock_info.receive_message
|
||||
del connection.receive_message
|
||||
|
||||
# The socket is returned to the pool and it still works.
|
||||
self.assertEqual(200, collection.count_documents({}))
|
||||
self.assertIn(sock_info, pool.sockets)
|
||||
self.assertIn(connection, pool.sockets)
|
||||
|
||||
def test_exhaust_query_network_error(self):
|
||||
# When doing an exhaust query, the socket stays checked out on success
|
||||
@ -1909,15 +1909,15 @@ class TestExhaustCursor(IntegrationTest):
|
||||
pool._check_interval_seconds = None # Never check.
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = one(pool.sockets)
|
||||
sock_info.sock.close()
|
||||
connection = one(pool.sockets)
|
||||
connection.connector.close()
|
||||
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
self.assertRaises(ConnectionFailure, cursor.next)
|
||||
self.assertTrue(sock_info.closed)
|
||||
self.assertTrue(connection.closed)
|
||||
|
||||
# The socket was closed and the semaphore was decremented.
|
||||
self.assertNotIn(sock_info, pool.sockets)
|
||||
self.assertNotIn(connection, pool.sockets)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
def test_exhaust_getmore_network_error(self):
|
||||
@ -1936,15 +1936,15 @@ class TestExhaustCursor(IntegrationTest):
|
||||
cursor.next()
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = cursor._Cursor__sock_mgr.sock
|
||||
sock_info.sock.close()
|
||||
connection = cursor._Cursor__sock_mgr.sock
|
||||
connection.connector.close()
|
||||
|
||||
# A getmore fails.
|
||||
self.assertRaises(ConnectionFailure, list, cursor)
|
||||
self.assertTrue(sock_info.closed)
|
||||
self.assertTrue(connection.closed)
|
||||
|
||||
# The socket was closed and the semaphore was decremented.
|
||||
self.assertNotIn(sock_info, pool.sockets)
|
||||
self.assertNotIn(connection, pool.sockets)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
def test_gevent_task(self):
|
||||
|
||||
@ -123,19 +123,19 @@ class TestCMAP(IntegrationTest):
|
||||
def check_out(self, op):
|
||||
"""Run the 'checkOut' operation."""
|
||||
label = op["label"]
|
||||
with self.pool.get_socket() as sock_info:
|
||||
with self.pool.get_socket() as connection:
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
sock_info.pin_cursor()
|
||||
connection.pin_cursor()
|
||||
if label:
|
||||
self.labels[label] = sock_info
|
||||
self.labels[label] = connection
|
||||
else:
|
||||
self.addCleanup(sock_info.close_socket, None)
|
||||
self.addCleanup(connection.close_socket, None)
|
||||
|
||||
def check_in(self, op):
|
||||
"""Run the 'checkIn' operation."""
|
||||
label = op["connection"]
|
||||
sock_info = self.labels[label]
|
||||
self.pool.return_socket(sock_info)
|
||||
connection = self.labels[label]
|
||||
self.pool.return_socket(connection)
|
||||
|
||||
def ready(self, op):
|
||||
"""Run the 'ready' operation."""
|
||||
|
||||
@ -94,7 +94,7 @@ def got_app_error(topology, app_error):
|
||||
when = app_error["when"]
|
||||
max_wire_version = app_error["maxWireVersion"]
|
||||
# XXX: We could get better test coverage by mocking the errors on the
|
||||
# Pool/SocketInfo.
|
||||
# Pool/Connection.
|
||||
try:
|
||||
if error_type == "command":
|
||||
_check_command_response(app_error["response"], max_wire_version)
|
||||
@ -279,7 +279,7 @@ class TestIgnoreStaleErrors(IntegrationTest):
|
||||
def mock_command(*args, **kwargs):
|
||||
# Synchronize all threads to ensure they use the same generation.
|
||||
barrier.wait()
|
||||
raise AutoReconnect("mock SocketInfo.command error")
|
||||
raise AutoReconnect("mock Connection.command error")
|
||||
|
||||
for sock in pool.sockets:
|
||||
sock.command = mock_command
|
||||
|
||||
@ -188,11 +188,11 @@ class TestPooling(_TestPoolingBase):
|
||||
# Test Pool's _check_closed() method doesn't close a healthy socket.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
with cx_pool.get_socket() as connection:
|
||||
pass
|
||||
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(connection, new_sock_info)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
@ -200,12 +200,12 @@ class TestPooling(_TestPoolingBase):
|
||||
# get_socket() returns socket after a non-network error.
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
with cx_pool.get_socket() as connection:
|
||||
1 / 0
|
||||
|
||||
# Socket was returned, not closed.
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(connection, new_sock_info)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
@ -213,9 +213,9 @@ class TestPooling(_TestPoolingBase):
|
||||
# Test that Pool removes explicitly closed socket.
|
||||
cx_pool = self.create_pool()
|
||||
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Use SocketInfo's API to close the socket.
|
||||
sock_info.close_socket(None)
|
||||
with cx_pool.get_socket() as connection:
|
||||
# Use Connection's API to close the socket.
|
||||
connection.close_socket(None)
|
||||
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
|
||||
@ -225,15 +225,15 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
with cx_pool.get_socket() as connection:
|
||||
# Simulate a closed socket without telling the Connection it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
self.assertTrue(sock_info.socket_closed())
|
||||
connection.connector.close()
|
||||
self.assertTrue(connection.socket_closed())
|
||||
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertNotEqual(connection, new_sock_info)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
|
||||
@ -299,10 +299,10 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
self.addCleanup(cx_pool.close)
|
||||
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
with cx_pool.get_socket() as connection:
|
||||
# Simulate a closed socket without telling the Connection it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
connection.connector.close()
|
||||
|
||||
# Swap pool's address with a bad one.
|
||||
address, cx_pool.address = cx_pool.address, ("foo.com", 1234)
|
||||
|
||||
@ -286,16 +286,16 @@ class ReadPrefTester(MongoClient):
|
||||
@contextlib.contextmanager
|
||||
def _socket_for_reads(self, read_preference, session):
|
||||
context = super()._socket_for_reads(read_preference, session)
|
||||
with context as (sock_info, read_preference):
|
||||
self.record_a_read(sock_info.address)
|
||||
yield sock_info, read_preference
|
||||
with context as (connection, read_preference):
|
||||
self.record_a_read(connection.address)
|
||||
yield connection, read_preference
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _socket_from_server(self, read_preference, server, session):
|
||||
context = super()._socket_from_server(read_preference, server, session)
|
||||
with context as (sock_info, read_preference):
|
||||
self.record_a_read(sock_info.address)
|
||||
yield sock_info, read_preference
|
||||
with context as (connection, read_preference):
|
||||
self.record_a_read(connection.address)
|
||||
yield connection, read_preference
|
||||
|
||||
def record_a_read(self, address):
|
||||
server = self._get_topology().select_server_by_address(address, 0)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user