Compare commits

...

1 Commits
master ... grpc

Author SHA1 Message Date
Noah Stapp
34ca694c9f
PYTHON-3801 gRPC POC phase 1 (#1317) 2023-07-26 14:01:22 -07:00
30 changed files with 638 additions and 508 deletions

View File

@ -29,7 +29,7 @@ if TYPE_CHECKING:
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor from pymongo.command_cursor import CommandCursor
from pymongo.database import Database from pymongo.database import Database
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.read_preferences import _ServerMode from pymongo.read_preferences import _ServerMode
from pymongo.server import Server from pymongo.server import Server
from pymongo.typings import _Pipeline from pymongo.typings import _Pipeline
@ -52,7 +52,7 @@ class _AggregationCommand:
explicit_session: bool, explicit_session: bool,
let: Optional[Mapping[str, Any]] = None, let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None, user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], SocketInfo], None]] = None, result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
comment: Any = None, comment: Any = None,
) -> None: ) -> None:
if "explain" in options: if "explain" in options:
@ -134,7 +134,7 @@ class _AggregationCommand:
self, self,
session: ClientSession, session: ClientSession,
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: _ServerMode, read_preference: _ServerMode,
) -> CommandCursor: ) -> CommandCursor:
# Serialize command. # Serialize command.
@ -146,7 +146,7 @@ class _AggregationCommand:
# - server version is >= 4.2 or # - server version is >= 4.2 or
# - server version is >= 3.2 and pipeline doesn't use $out # - server version is >= 3.2 and pipeline doesn't use $out
if ("readConcern" not in cmd) and ( if ("readConcern" not in cmd) and (
not self._performs_write or (sock_info.max_wire_version >= 8) not self._performs_write or (connection.max_wire_version >= 8)
): ):
read_concern = self._target.read_concern read_concern = self._target.read_concern
else: else:
@ -161,7 +161,7 @@ class _AggregationCommand:
write_concern = None write_concern = None
# Run command. # Run command.
result = sock_info.command( result = connection.command(
self._database.name, self._database.name,
cmd, cmd,
read_preference, read_preference,
@ -176,7 +176,7 @@ class _AggregationCommand:
) )
if self._result_processor: if self._result_processor:
self._result_processor(result, sock_info) self._result_processor(result, connection)
# Extract cursor from result or mock/fake one if necessary. # Extract cursor from result or mock/fake one if necessary.
if "cursor" in result: if "cursor" in result:
@ -193,14 +193,14 @@ class _AggregationCommand:
cmd_cursor = self._cursor_class( cmd_cursor = self._cursor_class(
self._cursor_collection(cursor), self._cursor_collection(cursor),
cursor, cursor,
sock_info.address, connection.address,
batch_size=self._batch_size or 0, batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms, max_await_time_ms=self._max_await_time_ms,
session=session, session=session,
explicit_session=self._explicit_session, explicit_session=self._explicit_session,
comment=self._options.get("comment"), comment=self._options.get("comment"),
) )
cmd_cursor._maybe_pin_connection(sock_info) cmd_cursor._maybe_pin_connection(connection)
return cmd_cursor return cmd_cursor

View File

@ -35,7 +35,7 @@ from pymongo.saslprep import saslprep
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.hello import Hello from pymongo.hello import Hello
from pymongo.pool import SocketInfo from pymongo.pool import Connection
HAVE_KERBEROS = True HAVE_KERBEROS = True
_USE_PRINCIPAL = False _USE_PRINCIPAL = False
@ -221,7 +221,7 @@ def _authenticate_scram_start(
def _authenticate_scram( def _authenticate_scram(
credentials: MongoCredential, sock_info: SocketInfo, mechanism: str credentials: MongoCredential, connection: Connection, mechanism: str
) -> None: ) -> None:
"""Authenticate using SCRAM.""" """Authenticate using SCRAM."""
username = credentials.username username = credentials.username
@ -239,13 +239,13 @@ def _authenticate_scram(
# Make local # Make local
_hmac = hmac.HMAC _hmac = hmac.HMAC
ctx = sock_info.auth_ctx ctx = connection.auth_ctx
if ctx and ctx.speculate_succeeded(): if ctx and ctx.speculate_succeeded():
nonce, first_bare = ctx.scram_data nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate res = ctx.speculative_authenticate
else: else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = sock_info.command(source, cmd) res = connection.command(source, cmd)
server_first = res["payload"] server_first = res["payload"]
parsed = _parse_scram_response(server_first) parsed = _parse_scram_response(server_first)
@ -285,7 +285,7 @@ def _authenticate_scram(
("payload", Binary(client_final)), ("payload", Binary(client_final)),
] ]
) )
res = sock_info.command(source, cmd) res = connection.command(source, cmd)
parsed = _parse_scram_response(res["payload"]) parsed = _parse_scram_response(res["payload"])
if not hmac.compare_digest(parsed[b"v"], server_sig): if not hmac.compare_digest(parsed[b"v"], server_sig):
@ -301,7 +301,7 @@ def _authenticate_scram(
("payload", Binary(b"")), ("payload", Binary(b"")),
] ]
) )
res = sock_info.command(source, cmd) res = connection.command(source, cmd)
if not res["done"]: if not res["done"]:
raise OperationFailure("SASL conversation failed to complete.") raise OperationFailure("SASL conversation failed to complete.")
@ -345,7 +345,7 @@ def _canonicalize_hostname(hostname: str) -> str:
return name[0].lower() return name[0].lower()
def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None: def _authenticate_gssapi(credentials: MongoCredential, connection: Connection) -> None:
"""Authenticate using GSSAPI.""" """Authenticate using GSSAPI."""
if not HAVE_KERBEROS: if not HAVE_KERBEROS:
raise ConfigurationError( raise ConfigurationError(
@ -358,7 +358,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
props = credentials.mechanism_properties props = credentials.mechanism_properties
# Starting here and continuing through the while loop below - establish # Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph. # the security context. See RFC 4752, Section 3.1, first paragraph.
host = sock_info.address[0] host = connection.address[0]
if props.canonicalize_host_name: if props.canonicalize_host_name:
host = _canonicalize_hostname(host) host = _canonicalize_hostname(host)
service = props.service_name + "@" + host service = props.service_name + "@" + host
@ -413,7 +413,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
("autoAuthorize", 1), ("autoAuthorize", 1),
] ]
) )
response = sock_info.command("$external", cmd) response = connection.command("$external", cmd)
# Limit how many times we loop to catch protocol / library issues # Limit how many times we loop to catch protocol / library issues
for _ in range(10): for _ in range(10):
@ -430,7 +430,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
("payload", payload), ("payload", payload),
] ]
) )
response = sock_info.command("$external", cmd) response = connection.command("$external", cmd)
if result == kerberos.AUTH_GSS_COMPLETE: if result == kerberos.AUTH_GSS_COMPLETE:
break break
@ -453,7 +453,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
("payload", payload), ("payload", payload),
] ]
) )
sock_info.command("$external", cmd) connection.command("$external", cmd)
finally: finally:
kerberos.authGSSClientClean(ctx) kerberos.authGSSClientClean(ctx)
@ -462,7 +462,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
raise OperationFailure(str(exc)) raise OperationFailure(str(exc))
def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None: def _authenticate_plain(credentials: MongoCredential, connection: Connection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)""" """Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source source = credentials.source
username = credentials.username username = credentials.username
@ -476,52 +476,52 @@ def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) ->
("autoAuthorize", 1), ("autoAuthorize", 1),
] ]
) )
sock_info.command(source, cmd) connection.command(source, cmd)
def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None: def _authenticate_x509(credentials: MongoCredential, connection: Connection) -> None:
"""Authenticate using MONGODB-X509.""" """Authenticate using MONGODB-X509."""
ctx = sock_info.auth_ctx ctx = connection.auth_ctx
if ctx and ctx.speculate_succeeded(): if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step. # MONGODB-X509 is done after the speculative auth step.
return return
cmd = _X509Context(credentials, sock_info.address).speculate_command() cmd = _X509Context(credentials, connection.address).speculate_command()
sock_info.command("$external", cmd) connection.command("$external", cmd)
def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None: def _authenticate_mongo_cr(credentials: MongoCredential, connection: Connection) -> None:
"""Authenticate using MONGODB-CR.""" """Authenticate using MONGODB-CR."""
source = credentials.source source = credentials.source
username = credentials.username username = credentials.username
password = credentials.password password = credentials.password
# Get a nonce # Get a nonce
response = sock_info.command(source, {"getnonce": 1}) response = connection.command(source, {"getnonce": 1})
nonce = response["nonce"] nonce = response["nonce"]
key = _auth_key(nonce, username, password) key = _auth_key(nonce, username, password)
# Actually authenticate # Actually authenticate
query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)]) query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)])
sock_info.command(source, query) connection.command(source, query)
def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None: def _authenticate_default(credentials: MongoCredential, connection: Connection) -> None:
if sock_info.max_wire_version >= 7: if connection.max_wire_version >= 7:
if sock_info.negotiated_mechs: if connection.negotiated_mechs:
mechs = sock_info.negotiated_mechs mechs = connection.negotiated_mechs
else: else:
source = credentials.source source = credentials.source
cmd = sock_info.hello_cmd() cmd = connection.hello_cmd()
cmd["saslSupportedMechs"] = source + "." + credentials.username cmd["saslSupportedMechs"] = source + "." + credentials.username
mechs = sock_info.command(source, cmd, publish_events=False).get( mechs = connection.command(source, cmd, publish_events=False).get(
"saslSupportedMechs", [] "saslSupportedMechs", []
) )
if "SCRAM-SHA-256" in mechs: if "SCRAM-SHA-256" in mechs:
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256") return _authenticate_scram(credentials, connection, "SCRAM-SHA-256")
else: else:
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") return _authenticate_scram(credentials, connection, "SCRAM-SHA-1")
else: else:
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") return _authenticate_scram(credentials, connection, "SCRAM-SHA-1")
_AUTH_MAP: Mapping[str, Callable] = { _AUTH_MAP: Mapping[str, Callable] = {
@ -606,12 +606,12 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
def authenticate( def authenticate(
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False credentials: MongoCredential, connection: Connection, reauthenticate: bool = False
) -> None: ) -> None:
"""Authenticate sock_info.""" """Authenticate connection."""
mechanism = credentials.mechanism mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism] auth_func = _AUTH_MAP[mechanism]
if mechanism == "MONGODB-OIDC": if mechanism == "MONGODB-OIDC":
_authenticate_oidc(credentials, sock_info, reauthenticate) _authenticate_oidc(credentials, connection, reauthenticate)
else: else:
auth_func(credentials, sock_info) auth_func(credentials, connection)

View File

@ -49,7 +49,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING: if TYPE_CHECKING:
from bson.typings import _ReadableBuffer from bson.typings import _ReadableBuffer
from pymongo.auth import MongoCredential from pymongo.auth import MongoCredential
from pymongo.pool import SocketInfo from pymongo.pool import Connection
class _AwsSaslContext(AwsSaslContext): # type: ignore class _AwsSaslContext(AwsSaslContext): # type: ignore
@ -67,7 +67,7 @@ class _AwsSaslContext(AwsSaslContext): # type: ignore
return bson.decode(data) return bson.decode(data)
def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> None: def _authenticate_aws(credentials: MongoCredential, connection: Connection) -> None:
"""Authenticate using MONGODB-AWS.""" """Authenticate using MONGODB-AWS."""
if not _HAVE_MONGODB_AWS: if not _HAVE_MONGODB_AWS:
raise ConfigurationError( raise ConfigurationError(
@ -75,7 +75,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
"install with: python -m pip install 'pymongo[aws]'" "install with: python -m pip install 'pymongo[aws]'"
) )
if sock_info.max_wire_version < 9: if connection.max_wire_version < 9:
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
try: try:
@ -90,7 +90,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
client_first = SON( client_first = SON(
[("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)] [("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)]
) )
server_first = sock_info.command("$external", client_first) server_first = connection.command("$external", client_first)
res = server_first res = server_first
# Limit how many times we loop to catch protocol / library issues # Limit how many times we loop to catch protocol / library issues
for _ in range(10): for _ in range(10):
@ -102,7 +102,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
("payload", client_payload), ("payload", client_payload),
] ]
) )
res = sock_info.command("$external", cmd) res = connection.command("$external", cmd)
if res["done"]: if res["done"]:
# SASL complete. # SASL complete.
break break

View File

@ -29,7 +29,7 @@ from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.auth import MongoCredential from pymongo.auth import MongoCredential
from pymongo.pool import SocketInfo from pymongo.pool import Connection
@dataclass @dataclass
@ -243,24 +243,24 @@ class _OIDCAuthenticator:
self.token_exp_utc = None self.token_exp_utc = None
def run_command( def run_command(
self, sock_info: SocketInfo, cmd: Mapping[str, Any] self, connection: Connection, cmd: Mapping[str, Any]
) -> Optional[Mapping[str, Any]]: ) -> Optional[Mapping[str, Any]]:
try: try:
return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] return connection.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as exc: except OperationFailure as exc:
self.clear() self.clear()
if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
if "jwt" in bson.decode(cmd["payload"]): if "jwt" in bson.decode(cmd["payload"]):
if self.idp_info_gen_id > self.reauth_gen_id: if self.idp_info_gen_id > self.reauth_gen_id:
raise raise
return self.authenticate(sock_info, reauthenticate=True) return self.authenticate(connection, reauthenticate=True)
raise raise
def authenticate( def authenticate(
self, sock_info: SocketInfo, reauthenticate: bool = False self, connection: Connection, reauthenticate: bool = False
) -> Optional[Mapping[str, Any]]: ) -> Optional[Mapping[str, Any]]:
if reauthenticate: if reauthenticate:
prev_id = getattr(sock_info, "oidc_token_gen_id", None) prev_id = getattr(connection, "oidc_token_gen_id", None)
# Check if we've already changed tokens. # Check if we've already changed tokens.
if prev_id == self.token_gen_id: if prev_id == self.token_gen_id:
self.reauth_gen_id = self.idp_info_gen_id self.reauth_gen_id = self.idp_info_gen_id
@ -268,7 +268,7 @@ class _OIDCAuthenticator:
if not self.properties.refresh_token_callback: if not self.properties.refresh_token_callback:
self.clear() self.clear()
ctx = sock_info.auth_ctx ctx = connection.auth_ctx
cmd = None cmd = None
if ctx and ctx.speculate_succeeded(): if ctx and ctx.speculate_succeeded():
@ -276,10 +276,10 @@ class _OIDCAuthenticator:
else: else:
cmd = self.auth_start_cmd() cmd = self.auth_start_cmd()
assert cmd is not None assert cmd is not None
resp = self.run_command(sock_info, cmd) resp = self.run_command(connection, cmd)
if resp["done"]: if resp["done"]:
sock_info.oidc_token_gen_id = self.token_gen_id connection.oidc_token_gen_id = self.token_gen_id
return None return None
server_resp: Dict = bson.decode(resp["payload"]) server_resp: Dict = bson.decode(resp["payload"])
@ -289,7 +289,7 @@ class _OIDCAuthenticator:
conversation_id = resp["conversationId"] conversation_id = resp["conversationId"]
token = self.get_current_token() token = self.get_current_token()
sock_info.oidc_token_gen_id = self.token_gen_id connection.oidc_token_gen_id = self.token_gen_id
bin_payload = Binary(bson.encode({"jwt": token})) bin_payload = Binary(bson.encode({"jwt": token}))
cmd = SON( cmd = SON(
[ [
@ -298,7 +298,7 @@ class _OIDCAuthenticator:
("payload", bin_payload), ("payload", bin_payload),
] ]
) )
resp = self.run_command(sock_info, cmd) resp = self.run_command(connection, cmd)
if not resp["done"]: if not resp["done"]:
self.clear() self.clear()
raise OperationFailure("SASL conversation failed to complete.") raise OperationFailure("SASL conversation failed to complete.")
@ -306,8 +306,8 @@ class _OIDCAuthenticator:
def _authenticate_oidc( def _authenticate_oidc(
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool credentials: MongoCredential, connection: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]: ) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC.""" """Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, sock_info.address) authenticator = _get_authenticator(credentials, connection.address)
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate) return authenticator.authenticate(connection, reauthenticate=reauthenticate)

View File

@ -65,7 +65,7 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
_DELETE_ALL: int = 0 _DELETE_ALL: int = 0
@ -311,7 +311,7 @@ class _Bulk:
generator: Iterator[Any], generator: Iterator[Any],
write_concern: WriteConcern, write_concern: WriteConcern,
session: Optional[ClientSession], session: Optional[ClientSession],
sock_info: SocketInfo, connection: Connection,
op_id: int, op_id: int,
retryable: bool, retryable: bool,
full_result: MutableMapping[str, Any], full_result: MutableMapping[str, Any],
@ -326,9 +326,9 @@ class _Bulk:
self.next_run = None self.next_run = None
run = self.current_run run = self.current_run
# sock_info.command validates the session, but we use # connection.command validates the session, but we use
# sock_info.write_command. # connection.write_command.
sock_info.validate_session(client, session) connection.validate_session(client, session)
last_run = False last_run = False
while run: while run:
@ -341,7 +341,7 @@ class _Bulk:
bwc = self.bulk_ctx_class( bwc = self.bulk_ctx_class(
db_name, db_name,
cmd_name, cmd_name,
sock_info, connection,
op_id, op_id,
listeners, listeners,
session, session,
@ -369,11 +369,11 @@ class _Bulk:
if retryable and not self.started_retryable_write: if retryable and not self.started_retryable_write:
session._start_retryable_write() session._start_retryable_write()
self.started_retryable_write = True self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info) session._apply_to(cmd, retryable, ReadPreference.PRIMARY, connection)
sock_info.send_cluster_time(cmd, session, client) connection.send_cluster_time(cmd, session, client)
sock_info.add_server_api(cmd) connection.add_server_api(cmd)
# CSOT: apply timeout before encoding the command. # CSOT: apply timeout before encoding the command.
sock_info.apply_timeout(client, cmd) connection.apply_timeout(client, cmd)
ops = islice(run.ops, run.idx_offset, None) ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible in one command. # Run as many ops as possible in one command.
@ -430,13 +430,13 @@ class _Bulk:
op_id = _randint() op_id = _randint()
def retryable_bulk( def retryable_bulk(
session: Optional[ClientSession], sock_info: SocketInfo, retryable: bool session: Optional[ClientSession], connection: Connection, retryable: bool
) -> None: ) -> None:
self._execute_command( self._execute_command(
generator, generator,
write_concern, write_concern,
session, session,
sock_info, connection,
op_id, op_id,
retryable, retryable,
full_result, full_result,
@ -450,7 +450,7 @@ class _Bulk:
_raise_bulk_write_error(full_result) _raise_bulk_write_error(full_result)
return full_result return full_result
def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[Any]) -> None: def execute_op_msg_no_results(self, connection: Connection, generator: Iterator[Any]) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name db_name = self.collection.database.name
client = self.collection.database.client client = self.collection.database.client
@ -466,7 +466,7 @@ class _Bulk:
bwc = self.bulk_ctx_class( bwc = self.bulk_ctx_class(
db_name, db_name,
cmd_name, cmd_name,
sock_info, connection,
op_id, op_id,
listeners, listeners,
None, None,
@ -482,7 +482,7 @@ class _Bulk:
("writeConcern", {"w": 0}), ("writeConcern", {"w": 0}),
] ]
) )
sock_info.add_server_api(cmd) connection.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None) ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible. # Run as many ops as possible.
to_send = bwc.execute_unack(cmd, ops, client) to_send = bwc.execute_unack(cmd, ops, client)
@ -491,7 +491,7 @@ class _Bulk:
def execute_command_no_results( def execute_command_no_results(
self, self,
sock_info: SocketInfo, connection: Connection,
generator: Iterator[Any], generator: Iterator[Any],
write_concern: WriteConcern, write_concern: WriteConcern,
) -> None: ) -> None:
@ -516,7 +516,7 @@ class _Bulk:
generator, generator,
initial_write_concern, initial_write_concern,
None, None,
sock_info, connection,
op_id, op_id,
False, False,
full_result, full_result,
@ -527,7 +527,7 @@ class _Bulk:
def execute_no_results( def execute_no_results(
self, self,
sock_info: SocketInfo, connection: Connection,
generator: Iterator[Any], generator: Iterator[Any],
write_concern: WriteConcern, write_concern: WriteConcern,
) -> None: ) -> None:
@ -538,11 +538,11 @@ class _Bulk:
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
# Guard against unsupported unacknowledged writes. # Guard against unsupported unacknowledged writes.
unack = write_concern and not write_concern.acknowledged unack = write_concern and not write_concern.acknowledged
if unack and self.uses_hint_delete and sock_info.max_wire_version < 9: if unack and self.uses_hint_delete and connection.max_wire_version < 9:
raise ConfigurationError( raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
) )
if unack and self.uses_hint_update and sock_info.max_wire_version < 8: if unack and self.uses_hint_update and connection.max_wire_version < 8:
raise ConfigurationError( raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
) )
@ -553,8 +553,8 @@ class _Bulk:
) )
if self.ordered: if self.ordered:
return self.execute_command_no_results(sock_info, generator, write_concern) return self.execute_command_no_results(connection, generator, write_concern)
return self.execute_op_msg_no_results(sock_info, generator) return self.execute_op_msg_no_results(connection, generator)
def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any: def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any:
"""Execute operations.""" """Execute operations."""
@ -573,8 +573,8 @@ class _Bulk:
client = self.collection.database.client client = self.collection.database.client
if not write_concern.acknowledged: if not write_concern.acknowledged:
with client._socket_for_writes(session) as sock_info: with client._socket_for_writes(session) as connection:
self.execute_no_results(sock_info, generator, write_concern) self.execute_no_results(connection, generator, write_concern)
return None return None
else: else:
return self.execute_command(generator, write_concern, session) return self.execute_command(generator, write_concern, session)

View File

@ -78,7 +78,7 @@ if TYPE_CHECKING:
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.database import Database from pymongo.database import Database
from pymongo.mongo_client import MongoClient from pymongo.mongo_client import MongoClient
from pymongo.pool import SocketInfo from pymongo.pool import Connection
def _resumable(exc: PyMongoError) -> bool: def _resumable(exc: PyMongoError) -> bool:
@ -213,7 +213,7 @@ class ChangeStream(Generic[_DocumentType]):
full_pipeline.extend(self._pipeline) full_pipeline.extend(self._pipeline)
return full_pipeline return full_pipeline
def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> None: def _process_result(self, result: Mapping[str, Any], connection: Connection) -> None:
"""Callback that caches the postBatchResumeToken or """Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents. containing an empty batch of change documents.
@ -228,7 +228,7 @@ class ChangeStream(Generic[_DocumentType]):
self._start_at_operation_time is None self._start_at_operation_time is None
and self._uses_resume_after is False and self._uses_resume_after is False
and self._uses_start_after is False and self._uses_start_after is False
and sock_info.max_wire_version >= 7 and connection.max_wire_version >= 7
): ):
self._start_at_operation_time = result.get("operationTime") self._start_at_operation_time = result.get("operationTime")
# PYTHON-2181: informative error on missing operationTime. # PYTHON-2181: informative error on missing operationTime.

View File

@ -24,7 +24,7 @@ from pymongo.common import validate_boolean
from pymongo.compression_support import CompressionSettings from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListeners from pymongo.monitoring import _EventListeners
from pymongo.pool import PoolOptions from pymongo.pool import ConnectionProtocol, PoolOptions
from pymongo.read_concern import ReadConcern from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ( from pymongo.read_preferences import (
_ServerMode, _ServerMode,
@ -162,6 +162,13 @@ def _parse_pool_options(
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
load_balanced = options.get("loadbalanced") load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
grpc_enabled = options.get("grpc", False)
if grpc_enabled:
protocol = ConnectionProtocol.GRPC
else:
protocol = ConnectionProtocol.TCP_SOCKET
return PoolOptions( return PoolOptions(
max_pool_size, max_pool_size,
min_pool_size, min_pool_size,
@ -179,6 +186,7 @@ def _parse_pool_options(
server_api=server_api, server_api=server_api,
load_balanced=load_balanced, load_balanced=load_balanced,
credentials=credentials, credentials=credentials,
protocol=protocol,
) )

View File

@ -178,7 +178,7 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.server import Server from pymongo.server import Server
@ -412,17 +412,17 @@ class _Transaction:
return self.state == _TxnState.STARTING return self.state == _TxnState.STARTING
@property @property
def pinned_conn(self) -> Optional[SocketInfo]: def pinned_conn(self) -> Optional[Connection]:
if self.active() and self.sock_mgr: if self.active() and self.sock_mgr:
return self.sock_mgr.sock return self.sock_mgr.sock
return None return None
def pin(self, server: Server, sock_info: SocketInfo) -> None: def pin(self, server: Server, connection: Connection) -> None:
self.sharded = True self.sharded = True
self.pinned_address = server.description.address self.pinned_address = server.description.address
if server.description.server_type == SERVER_TYPE.LoadBalancer: if server.description.server_type == SERVER_TYPE.LoadBalancer:
sock_info.pin_txn() connection.pin_txn()
self.sock_mgr = _SocketManager(sock_info, False) self.sock_mgr = _SocketManager(connection, False)
def unpin(self) -> None: def unpin(self) -> None:
self.pinned_address = None self.pinned_address = None
@ -839,12 +839,12 @@ class ClientSession:
- `command_name`: Either "commitTransaction" or "abortTransaction". - `command_name`: Either "commitTransaction" or "abortTransaction".
""" """
def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]: def func(session: ClientSession, connection: Connection, retryable: bool) -> Dict[str, Any]:
return self._finish_transaction(sock_info, command_name) return self._finish_transaction(connection, command_name)
return self._client._retry_internal(True, func, self, None) return self._client._retry_internal(True, func, self, None)
def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]: def _finish_transaction(self, connection: Connection, command_name: str) -> Dict[str, Any]:
self._transaction.attempt += 1 self._transaction.attempt += 1
opts = self._transaction.opts opts = self._transaction.opts
assert opts assert opts
@ -868,7 +868,7 @@ class ClientSession:
cmd["recoveryToken"] = self._transaction.recovery_token cmd["recoveryToken"] = self._transaction.recovery_token
return self._client.admin._command( return self._client.admin._command(
sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True connection, cmd, session=self, write_concern=wc, parse_write_concern_error=True
) )
def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
@ -954,13 +954,13 @@ class ClientSession:
return None return None
@property @property
def _pinned_connection(self) -> Optional[SocketInfo]: def _pinned_connection(self) -> Optional[Connection]:
"""The connection this transaction was started on.""" """The connection this transaction was started on."""
return self._transaction.pinned_conn return self._transaction.pinned_conn
def _pin(self, server: Server, sock_info: SocketInfo) -> None: def _pin(self, server: Server, connection: Connection) -> None:
"""Pin this session to the given Server or to the given connection.""" """Pin this session to the given Server or to the given connection."""
self._transaction.pin(server, sock_info) self._transaction.pin(server, connection)
def _unpin(self) -> None: def _unpin(self) -> None:
"""Unpin this session from any pinned Server.""" """Unpin this session from any pinned Server."""
@ -985,12 +985,12 @@ class ClientSession:
command: MutableMapping[str, Any], command: MutableMapping[str, Any],
is_retryable: bool, is_retryable: bool,
read_preference: ReadPreference, read_preference: ReadPreference,
sock_info: SocketInfo, connection: Connection,
) -> None: ) -> None:
self._check_ended() self._check_ended()
self._materialize() self._materialize()
if self.options.snapshot: if self.options.snapshot:
self._update_read_concern(command, sock_info) self._update_read_concern(command, connection)
self._server_session.last_use = time.monotonic() self._server_session.last_use = time.monotonic()
command["lsid"] = self._server_session.session_id command["lsid"] = self._server_session.session_id
@ -1016,7 +1016,7 @@ class ClientSession:
rc = self._transaction.opts.read_concern.document rc = self._transaction.opts.read_concern.document
if rc: if rc:
command["readConcern"] = rc command["readConcern"] = rc
self._update_read_concern(command, sock_info) self._update_read_concern(command, connection)
command["txnNumber"] = self._server_session.transaction_id command["txnNumber"] = self._server_session.transaction_id
command["autocommit"] = False command["autocommit"] = False
@ -1025,11 +1025,11 @@ class ClientSession:
self._check_ended() self._check_ended()
self._server_session.inc_transaction_id() self._server_session.inc_transaction_id()
def _update_read_concern(self, cmd: MutableMapping[str, Any], sock_info: SocketInfo) -> None: def _update_read_concern(self, cmd: MutableMapping[str, Any], connection: Connection) -> None:
if self.options.causal_consistency and self.operation_time is not None: if self.options.causal_consistency and self.operation_time is not None:
cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time
if self.options.snapshot: if self.options.snapshot:
if sock_info.max_wire_version < 13: if connection.max_wire_version < 13:
raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later") raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later")
rc = cmd.setdefault("readConcern", {}) rc = cmd.setdefault("readConcern", {})
rc["level"] = "snapshot" rc["level"] = "snapshot"

View File

@ -126,7 +126,7 @@ if TYPE_CHECKING:
from pymongo.client_session import ClientSession from pymongo.client_session import ClientSession
from pymongo.collation import Collation from pymongo.collation import Collation
from pymongo.database import Database from pymongo.database import Database
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern from pymongo.read_concern import ReadConcern
from pymongo.server import Server from pymongo.server import Server
@ -264,15 +264,15 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _socket_for_reads( def _socket_for_reads(
self, session: ClientSession self, session: ClientSession
) -> ContextManager[Tuple[SocketInfo, Union[PrimaryPreferred, Primary]]]: ) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]:
return self.__database.client._socket_for_reads(self._read_preference_for(session), session) return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[SocketInfo]: def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]:
return self.__database.client._socket_for_writes(session) return self.__database.client._socket_for_writes(session)
def _command( def _command(
self, self,
sock_info: SocketInfo, connection: Connection,
command: Mapping[str, Any], command: Mapping[str, Any],
read_preference: Optional[_ServerMode] = None, read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = None, codec_options: Optional[CodecOptions] = None,
@ -288,7 +288,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""Internal command helper. """Internal command helper.
:Parameters: :Parameters:
- `sock_info` - A SocketInfo instance. - `connection` - A Connection instance.
- `command` - The command itself, as a :class:`~bson.son.SON` instance. - `command` - The command itself, as a :class:`~bson.son.SON` instance.
- `read_preference` (optional) - The read preference to use. - `read_preference` (optional) - The read preference to use.
- `codec_options` (optional) - An instance of - `codec_options` (optional) - An instance of
@ -313,7 +313,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
The result document. The result document.
""" """
with self.__database.client._tmp_session(session) as s: with self.__database.client._tmp_session(session) as s:
return sock_info.command( return connection.command(
self.__database.name, self.__database.name,
command, command,
read_preference or self._read_preference_for(session), read_preference or self._read_preference_for(session),
@ -348,16 +348,16 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if "size" in options: if "size" in options:
options["size"] = float(options["size"]) options["size"] = float(options["size"])
cmd.update(options) cmd.update(options)
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
if qev2_required and sock_info.max_wire_version < 21: if qev2_required and connection.max_wire_version < 21:
raise ConfigurationError( raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. " "Driver support of Queryable Encryption is incompatible with server. "
"Upgrade server to use Queryable Encryption. " "Upgrade server to use Queryable Encryption. "
f"Got maxWireVersion {sock_info.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)" f"Got maxWireVersion {connection.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
) )
self._command( self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),
@ -597,12 +597,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command["comment"] = comment command["comment"] = comment
def _insert_command( def _insert_command(
session: ClientSession, sock_info: SocketInfo, retryable_write: bool session: ClientSession, connection: Connection, retryable_write: bool
) -> None: ) -> None:
if bypass_doc_val: if bypass_doc_val:
command["bypassDocumentValidation"] = True command["bypassDocumentValidation"] = True
result = sock_info.command( result = connection.command(
self.__database.name, self.__database.name,
command, command,
write_concern=write_concern, write_concern=write_concern,
@ -765,7 +765,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _update( def _update(
self, self,
sock_info: SocketInfo, connection: Connection,
criteria: Mapping[str, Any], criteria: Mapping[str, Any],
document: Union[Mapping[str, Any], _Pipeline], document: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False, upsert: bool = False,
@ -801,7 +801,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
else: else:
update_doc["arrayFilters"] = array_filters update_doc["arrayFilters"] = array_filters
if hint is not None: if hint is not None:
if not acknowledged and sock_info.max_wire_version < 8: if not acknowledged and connection.max_wire_version < 8:
raise ConfigurationError( raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
) )
@ -821,7 +821,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
# The command result has to be published for APM unmodified # The command result has to be published for APM unmodified
# so we make a shallow copy here before adding updatedExisting. # so we make a shallow copy here before adding updatedExisting.
result = sock_info.command( result = connection.command(
self.__database.name, self.__database.name,
command, command,
write_concern=write_concern, write_concern=write_concern,
@ -865,10 +865,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""Internal update / replace helper.""" """Internal update / replace helper."""
def _update( def _update(
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool session: Optional[ClientSession], connection: Connection, retryable_write: bool
) -> Optional[Mapping[str, Any]]: ) -> Optional[Mapping[str, Any]]:
return self._update( return self._update(
sock_info, connection,
criteria, criteria,
document, document,
upsert=upsert, upsert=upsert,
@ -1255,7 +1255,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _delete( def _delete(
self, self,
sock_info: SocketInfo, connection: Connection,
criteria: Mapping[str, Any], criteria: Mapping[str, Any],
multi: bool, multi: bool,
write_concern: Optional[WriteConcern] = None, write_concern: Optional[WriteConcern] = None,
@ -1280,7 +1280,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
else: else:
delete_doc["collation"] = collation delete_doc["collation"] = collation
if hint is not None: if hint is not None:
if not acknowledged and sock_info.max_wire_version < 9: if not acknowledged and connection.max_wire_version < 9:
raise ConfigurationError( raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
) )
@ -1297,7 +1297,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command["comment"] = comment command["comment"] = comment
# Delete command. # Delete command.
result = sock_info.command( result = connection.command(
self.__database.name, self.__database.name,
command, command,
write_concern=write_concern, write_concern=write_concern,
@ -1325,10 +1325,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""Internal delete helper.""" """Internal delete helper."""
def _delete( def _delete(
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool session: Optional[ClientSession], connection: Connection, retryable_write: bool
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
return self._delete( return self._delete(
sock_info, connection,
criteria, criteria,
multi, multi,
write_concern=write_concern, write_concern=write_concern,
@ -1738,7 +1738,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _count_cmd( def _count_cmd(
self, self,
session: ClientSession, session: ClientSession,
sock_info: SocketInfo, connection: Connection,
read_preference: Optional[_ServerMode], read_preference: Optional[_ServerMode],
cmd: Mapping[str, Any], cmd: Mapping[str, Any],
collation: Optional[Collation], collation: Optional[Collation],
@ -1747,7 +1747,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
# XXX: "ns missing" checks can be removed when we drop support for # XXX: "ns missing" checks can be removed when we drop support for
# MongoDB 3.0, see SERVER-17051. # MongoDB 3.0, see SERVER-17051.
res = self._command( res = self._command(
sock_info, connection,
cmd, cmd,
read_preference=read_preference, read_preference=read_preference,
allowable_errors=["ns missing"], allowable_errors=["ns missing"],
@ -1762,7 +1762,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _aggregate_one_result( def _aggregate_one_result(
self, self,
sock_info: SocketInfo, connection: Connection,
read_preference: Optional[_ServerMode], read_preference: Optional[_ServerMode],
cmd: Mapping[str, Any], cmd: Mapping[str, Any],
collation: Optional[_CollationIn], collation: Optional[_CollationIn],
@ -1770,7 +1770,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
) -> Optional[Mapping[str, Any]]: ) -> Optional[Mapping[str, Any]]:
"""Internal helper to run an aggregate that returns a single result.""" """Internal helper to run an aggregate that returns a single result."""
result = self._command( result = self._command(
sock_info, connection,
cmd, cmd,
read_preference, read_preference,
allowable_errors=[26], # Ignore NamespaceNotFound. allowable_errors=[26], # Ignore NamespaceNotFound.
@ -1821,12 +1821,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd( def _cmd(
session: ClientSession, session: ClientSession,
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: Optional[_ServerMode], read_preference: Optional[_ServerMode],
) -> int: ) -> int:
cmd: SON[str, Any] = SON([("count", self.__name)]) cmd: SON[str, Any] = SON([("count", self.__name)])
cmd.update(kwargs) cmd.update(kwargs)
return self._count_cmd(session, sock_info, read_preference, cmd, collation=None) return self._count_cmd(session, connection, read_preference, cmd, collation=None)
return self._retryable_non_cursor_read(_cmd, None) return self._retryable_non_cursor_read(_cmd, None)
@ -1910,10 +1910,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd( def _cmd(
session: ClientSession, session: ClientSession,
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: Optional[_ServerMode], read_preference: Optional[_ServerMode],
) -> int: ) -> int:
result = self._aggregate_one_result(sock_info, read_preference, cmd, collation, session) result = self._aggregate_one_result(
connection, read_preference, cmd, collation, session
)
if not result: if not result:
return 0 return 0
return result["n"] return result["n"]
@ -1922,7 +1924,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _retryable_non_cursor_read( def _retryable_non_cursor_read(
self, self,
func: Callable[[ClientSession, Server, SocketInfo, Optional[_ServerMode]], T], func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T],
session: Optional[ClientSession], session: Optional[ClientSession],
) -> T: ) -> T:
"""Non-cursor read helper to handle implicit session creation.""" """Non-cursor read helper to handle implicit session creation."""
@ -1993,8 +1995,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments. command (like maxTimeMS) can be passed as keyword arguments.
""" """
names = [] names = []
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
supports_quorum = sock_info.max_wire_version >= 9 supports_quorum = connection.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]: def gen_indexes() -> Iterator[Mapping[str, Any]]:
for index in indexes: for index in indexes:
@ -2015,7 +2017,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
) )
self._command( self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
@ -2236,9 +2238,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
self._command( self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
@ -2285,7 +2287,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd( def _cmd(
session: ClientSession, session: ClientSession,
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: _ServerMode, read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]: ) -> CommandCursor[_DocumentType]:
cmd = SON([("listIndexes", self.__name), ("cursor", {})]) cmd = SON([("listIndexes", self.__name), ("cursor", {})])
@ -2294,7 +2296,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
try: try:
cursor = self._command( cursor = self._command(
sock_info, cmd, read_preference, codec_options, session=session connection, cmd, read_preference, codec_options, session=session
)["cursor"] )["cursor"]
except OperationFailure as exc: except OperationFailure as exc:
# Ignore NamespaceNotFound errors to match the behavior # Ignore NamespaceNotFound errors to match the behavior
@ -2305,12 +2307,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd_cursor = CommandCursor( cmd_cursor = CommandCursor(
coll, coll,
cursor, cursor,
sock_info.address, connection.address,
session=session, session=session,
explicit_session=explicit_session, explicit_session=explicit_session,
comment=cmd.get("comment"), comment=cmd.get("comment"),
) )
cmd_cursor._maybe_pin_connection(sock_info) cmd_cursor._maybe_pin_connection(connection)
return cmd_cursor return cmd_cursor
with self.__database.client._tmp_session(session, False) as s: with self.__database.client._tmp_session(session, False) as s:
@ -2479,9 +2481,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))]) cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))])
cmd.update(kwargs) cmd.update(kwargs)
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
resp = self._command( resp = self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
@ -2514,9 +2516,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
self._command( self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
@ -2551,9 +2553,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
self._command( self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
@ -2980,9 +2982,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd["comment"] = comment cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session) write_concern = self._write_concern_for_cmd(cmd, session)
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
with self.__database.client._tmp_session(session) as s: with self.__database.client._tmp_session(session) as s:
return sock_info.command( return connection.command(
"admin", "admin",
cmd, cmd,
write_concern=write_concern, write_concern=write_concern,
@ -3049,11 +3051,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd( def _cmd(
session: ClientSession, session: ClientSession,
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: Optional[_ServerMode], read_preference: Optional[_ServerMode],
) -> List: ) -> List:
return self._command( return self._command(
sock_info, connection,
cmd, cmd,
read_preference=read_preference, read_preference=read_preference,
read_concern=self.read_concern, read_concern=self.read_concern,
@ -3112,7 +3114,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern = self._write_concern_for_cmd(cmd, session) write_concern = self._write_concern_for_cmd(cmd, session)
def _find_and_modify( def _find_and_modify(
session: ClientSession, sock_info: SocketInfo, retryable_write: bool session: ClientSession, connection: Connection, retryable_write: bool
) -> Any: ) -> Any:
acknowledged = write_concern.acknowledged acknowledged = write_concern.acknowledged
if array_filters is not None: if array_filters is not None:
@ -3122,17 +3124,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
) )
cmd["arrayFilters"] = list(array_filters) cmd["arrayFilters"] = list(array_filters)
if hint is not None: if hint is not None:
if sock_info.max_wire_version < 8: if connection.max_wire_version < 8:
raise ConfigurationError( raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on find and modify commands." "Must be connected to MongoDB 4.2+ to use hint on find and modify commands."
) )
elif not acknowledged and sock_info.max_wire_version < 9: elif not acknowledged and connection.max_wire_version < 9:
raise ConfigurationError( raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands." "Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands."
) )
cmd["hint"] = hint cmd["hint"] = hint
out = self._command( out = self._command(
sock_info, connection,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
write_concern=write_concern, write_concern=write_concern,

View File

@ -38,7 +38,7 @@ from pymongo.typings import _Address, _DocumentType
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.client_session import ClientSession from pymongo.client_session import ClientSession
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.pool import SocketInfo from pymongo.pool import Connection
class CommandCursor(Generic[_DocumentType]): class CommandCursor(Generic[_DocumentType]):
@ -157,13 +157,13 @@ class CommandCursor(Generic[_DocumentType]):
""" """
return self.__postbatchresumetoken return self.__postbatchresumetoken
def _maybe_pin_connection(self, sock_info: SocketInfo) -> None: def _maybe_pin_connection(self, connection: Connection) -> None:
client = self.__collection.database.client client = self.__collection.database.client
if not client._should_pin_cursor(self.__session): if not client._should_pin_cursor(self.__session):
return return
if not self.__sock_mgr: if not self.__sock_mgr:
sock_info.pin_cursor() connection.pin_cursor()
sock_mgr = _SocketManager(sock_info, False) sock_mgr = _SocketManager(connection, False)
# Ensure the connection gets returned when the entire result is # Ensure the connection gets returned when the entire result is
# returned in the first batch. # returned in the first batch.
if self.__id == 0: if self.__id == 0:

View File

@ -750,6 +750,7 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
"server_selector": validate_is_callable_or_none, "server_selector": validate_is_callable_or_none,
"auto_encryption_opts": validate_auto_encryption_opts_or_none, "auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list, "authoidcallowedhosts": validate_list,
"grpc": validate_boolean,
} }
# Dictionary where keys are any URI option name, and values are the # Dictionary where keys are any URI option name, and values are the

View File

@ -64,7 +64,7 @@ if TYPE_CHECKING:
from pymongo.client_session import ClientSession from pymongo.client_session import ClientSession
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.message import _OpMsg, _OpReply from pymongo.message import _OpMsg, _OpReply
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.read_preferences import _ServerMode from pymongo.read_preferences import _ServerMode
@ -142,8 +142,8 @@ class CursorType:
class _SocketManager: class _SocketManager:
"""Used with exhaust cursors to ensure the socket is returned.""" """Used with exhaust cursors to ensure the socket is returned."""
def __init__(self, sock: SocketInfo, more_to_come: bool): def __init__(self, sock: Connection, more_to_come: bool):
self.sock: Optional[SocketInfo] = sock self.sock: Optional[Connection] = sock
self.more_to_come = more_to_come self.more_to_come = more_to_come
self.lock = _create_lock() self.lock = _create_lock()

View File

@ -48,7 +48,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.server import Server from pymongo.server import Server
@ -689,7 +689,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
@overload @overload
def _command( def _command(
self, self,
sock_info: SocketInfo, connection: Connection,
command: Union[str, MutableMapping[str, Any]], command: Union[str, MutableMapping[str, Any]],
value: int = 1, value: int = 1,
check: bool = True, check: bool = True,
@ -706,7 +706,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
@overload @overload
def _command( def _command(
self, self,
sock_info: SocketInfo, connection: Connection,
command: Union[str, MutableMapping[str, Any]], command: Union[str, MutableMapping[str, Any]],
value: int = 1, value: int = 1,
check: bool = True, check: bool = True,
@ -722,7 +722,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _command( def _command(
self, self,
sock_info: SocketInfo, connection: Connection,
command: Union[str, MutableMapping[str, Any]], command: Union[str, MutableMapping[str, Any]],
value: int = 1, value: int = 1,
check: bool = True, check: bool = True,
@ -742,7 +742,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
command.update(kwargs) command.update(kwargs)
with self.__client._tmp_session(session) as s: with self.__client._tmp_session(session) as s:
return sock_info.command( return connection.command(
self.__name, self.__name,
command, command,
read_preference, read_preference,
@ -890,11 +890,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if read_preference is None: if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
with self.__client._socket_for_reads(read_preference, session) as ( with self.__client._socket_for_reads(read_preference, session) as (
sock_info, connection,
read_preference, read_preference,
): ):
return self._command( return self._command(
sock_info, connection,
command, command,
value, value,
check, check,
@ -974,11 +974,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
tmp_session and tmp_session._txn_read_preference() tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY ) or ReadPreference.PRIMARY
with self.__client._socket_for_reads(read_preference, tmp_session) as ( with self.__client._socket_for_reads(read_preference, tmp_session) as (
sock_info, connection,
read_preference, read_preference,
): ):
response = self._command( response = self._command(
sock_info, connection,
command, command,
value, value,
True, True,
@ -993,13 +993,13 @@ class Database(common.BaseObject, Generic[_DocumentType]):
cmd_cursor = CommandCursor( cmd_cursor = CommandCursor(
coll, coll,
response["cursor"], response["cursor"],
sock_info.address, connection.address,
max_await_time_ms=max_await_time_ms, max_await_time_ms=max_await_time_ms,
session=tmp_session, session=tmp_session,
explicit_session=session is not None, explicit_session=session is not None,
comment=comment, comment=comment,
) )
cmd_cursor._maybe_pin_connection(sock_info) cmd_cursor._maybe_pin_connection(connection)
return cmd_cursor return cmd_cursor
else: else:
raise InvalidOperation("Command does not return a cursor.") raise InvalidOperation("Command does not return a cursor.")
@ -1015,11 +1015,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _cmd( def _cmd(
session: Optional[ClientSession], session: Optional[ClientSession],
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: _ServerMode, read_preference: _ServerMode,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return self._command( return self._command(
sock_info, connection,
command, command,
read_preference=read_preference, read_preference=read_preference,
session=session, session=session,
@ -1029,7 +1029,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _list_collections( def _list_collections(
self, self,
sock_info: SocketInfo, connection: Connection,
session: Optional[ClientSession], session: Optional[ClientSession],
read_preference: _ServerMode, read_preference: _ServerMode,
**kwargs: Any, **kwargs: Any,
@ -1040,17 +1040,17 @@ class Database(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
with self.__client._tmp_session(session, close=False) as tmp_session: with self.__client._tmp_session(session, close=False) as tmp_session:
cursor = self._command( cursor = self._command(
sock_info, cmd, read_preference=read_preference, session=tmp_session connection, cmd, read_preference=read_preference, session=tmp_session
)["cursor"] )["cursor"]
cmd_cursor = CommandCursor( cmd_cursor = CommandCursor(
coll, coll,
cursor, cursor,
sock_info.address, connection.address,
session=tmp_session, session=tmp_session,
explicit_session=session is not None, explicit_session=session is not None,
comment=cmd.get("comment"), comment=cmd.get("comment"),
) )
cmd_cursor._maybe_pin_connection(sock_info) cmd_cursor._maybe_pin_connection(connection)
return cmd_cursor return cmd_cursor
def list_collections( def list_collections(
@ -1090,11 +1090,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _cmd( def _cmd(
session: Optional[ClientSession], session: Optional[ClientSession],
server: Server, server: Server,
sock_info: SocketInfo, connection: Connection,
read_preference: _ServerMode, read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]: ) -> CommandCursor[_DocumentType]:
return self._list_collections( return self._list_collections(
sock_info, session, read_preference=read_preference, **kwargs connection, session, read_preference=read_preference, **kwargs
) )
return self.__client._retryable_read(_cmd, read_pref, session) return self.__client._retryable_read(_cmd, read_pref, session)
@ -1154,9 +1154,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if comment is not None: if comment is not None:
command["comment"] = comment command["comment"] = comment
with self.__client._socket_for_writes(session) as sock_info: with self.__client._socket_for_writes(session) as connection:
return self._command( return self._command(
sock_info, connection,
command, command,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),

View File

@ -307,7 +307,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F: def _handle_reauth(func: F) -> F:
def inner(*args: Any, **kwargs: Any) -> Any: def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False) no_reauth = kwargs.pop("no_reauth", False)
from pymongo.pool import SocketInfo from pymongo.pool import Connection
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
@ -315,19 +315,19 @@ def _handle_reauth(func: F) -> F:
if no_reauth: if no_reauth:
raise raise
if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
# Look for an argument that either is a SocketInfo # Look for an argument that either is a Connection
# or has a socket_info attribute, so we can trigger # or has a socket_info attribute, so we can trigger
# a reauth. # a reauth.
sock_info = None connection = None
for arg in args: for arg in args:
if isinstance(arg, SocketInfo): if isinstance(arg, Connection):
sock_info = arg connection = arg
break break
if hasattr(arg, "sock_info"): if hasattr(arg, "connection"):
sock_info = arg.sock_info connection = arg.connection
break break
if sock_info: if connection:
sock_info.authenticate(reauthenticate=True) connection.authenticate(reauthenticate=True)
else: else:
raise raise
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@ -227,14 +227,14 @@ def _gen_find_command(
return cmd return cmd
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, sock_info): def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, connection):
"""Generate a getMore command document.""" """Generate a getMore command document."""
cmd = SON([("getMore", cursor_id), ("collection", coll)]) cmd = SON([("getMore", cursor_id), ("collection", coll)])
if batch_size: if batch_size:
cmd["batchSize"] = batch_size cmd["batchSize"] = batch_size
if max_await_time_ms is not None: if max_await_time_ms is not None:
cmd["maxTimeMS"] = max_await_time_ms cmd["maxTimeMS"] = max_await_time_ms
if comment is not None and sock_info.max_wire_version >= 9: if comment is not None and connection.max_wire_version >= 9:
cmd["comment"] = comment cmd["comment"] = comment
return cmd return cmd
@ -311,24 +311,24 @@ class _Query:
def namespace(self): def namespace(self):
return f"{self.db}.{self.coll}" return f"{self.db}.{self.coll}"
def use_command(self, sock_info): def use_command(self, connection):
use_find_cmd = False use_find_cmd = False
if not self.exhaust: if not self.exhaust:
use_find_cmd = True use_find_cmd = True
elif sock_info.max_wire_version >= 8: elif connection.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+ # OP_MSG supports exhaust on MongoDB 4.2+
use_find_cmd = True use_find_cmd = True
elif not self.read_concern.ok_for_legacy: elif not self.read_concern.ok_for_legacy:
raise ConfigurationError( raise ConfigurationError(
"read concern level of %s is not valid " "read concern level of %s is not valid "
"with a max wire version of %d." "with a max wire version of %d."
% (self.read_concern.level, sock_info.max_wire_version) % (self.read_concern.level, connection.max_wire_version)
) )
sock_info.validate_session(self.client, self.session) connection.validate_session(self.client, self.session)
return use_find_cmd return use_find_cmd
def as_command(self, sock_info, apply_timeout=False): def as_command(self, connection, apply_timeout=False):
"""Return a find command document for this query.""" """Return a find command document for this query."""
# We use the command twice: on the wire and for command monitoring. # We use the command twice: on the wire and for command monitoring.
# Generate it once, for speed and to avoid repeating side-effects. # Generate it once, for speed and to avoid repeating side-effects.
@ -353,24 +353,24 @@ class _Query:
self.name = "explain" self.name = "explain"
cmd = SON([("explain", cmd)]) cmd = SON([("explain", cmd)])
session = self.session session = self.session
sock_info.add_server_api(cmd) connection.add_server_api(cmd)
if session: if session:
session._apply_to(cmd, False, self.read_preference, sock_info) session._apply_to(cmd, False, self.read_preference, connection)
# Explain does not support readConcern. # Explain does not support readConcern.
if not explain and not session.in_transaction: if not explain and not session.in_transaction:
session._update_read_concern(cmd, sock_info) session._update_read_concern(cmd, connection)
sock_info.send_cluster_time(cmd, session, self.client) connection.send_cluster_time(cmd, session, self.client)
# Support auto encryption # Support auto encryption
client = self.client client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption: if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
# Support CSOT # Support CSOT
if apply_timeout: if apply_timeout:
sock_info.apply_timeout(client, cmd) connection.apply_timeout(client, cmd)
self._as_command = cmd, self.db self._as_command = cmd, self.db
return self._as_command return self._as_command
def get_message(self, read_preference, sock_info, use_cmd=False): def get_message(self, read_preference, connection, use_cmd=False):
"""Get a query message, possibly setting the secondaryOk bit.""" """Get a query message, possibly setting the secondaryOk bit."""
# Use the read_preference decided by _socket_from_server. # Use the read_preference decided by _socket_from_server.
self.read_preference = read_preference self.read_preference = read_preference
@ -384,14 +384,14 @@ class _Query:
spec = self.spec spec = self.spec
if use_cmd: if use_cmd:
spec = self.as_command(sock_info, apply_timeout=True)[0] spec = self.as_command(connection, apply_timeout=True)[0]
request_id, msg, size, _ = _op_msg( request_id, msg, size, _ = _op_msg(
0, 0,
spec, spec,
self.db, self.db,
read_preference, read_preference,
self.codec_options, self.codec_options,
ctx=sock_info.compression_context, ctx=connection.compression_context,
) )
return request_id, msg, size return request_id, msg, size
@ -405,7 +405,7 @@ class _Query:
else: else:
ntoreturn = self.limit ntoreturn = self.limit
if sock_info.is_mongos: if connection.is_mongos:
spec = _maybe_add_read_preference(spec, read_preference) spec = _maybe_add_read_preference(spec, read_preference)
return _query( return _query(
@ -416,7 +416,7 @@ class _Query:
spec, spec,
None if use_cmd else self.fields, None if use_cmd else self.fields,
self.codec_options, self.codec_options,
ctx=sock_info.compression_context, ctx=connection.compression_context,
) )
@ -476,18 +476,18 @@ class _GetMore:
def namespace(self): def namespace(self):
return f"{self.db}.{self.coll}" return f"{self.db}.{self.coll}"
def use_command(self, sock_info): def use_command(self, connection):
use_cmd = False use_cmd = False
if not self.exhaust: if not self.exhaust:
use_cmd = True use_cmd = True
elif sock_info.max_wire_version >= 8: elif connection.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+ # OP_MSG supports exhaust on MongoDB 4.2+
use_cmd = True use_cmd = True
sock_info.validate_session(self.client, self.session) connection.validate_session(self.client, self.session)
return use_cmd return use_cmd
def as_command(self, sock_info, apply_timeout=False): def as_command(self, connection, apply_timeout=False):
"""Return a getMore command document for this query.""" """Return a getMore command document for this query."""
# See _Query.as_command for an explanation of this caching. # See _Query.as_command for an explanation of this caching.
if self._as_command is not None: if self._as_command is not None:
@ -499,35 +499,35 @@ class _GetMore:
self.ntoreturn, self.ntoreturn,
self.max_await_time_ms, self.max_await_time_ms,
self.comment, self.comment,
sock_info, connection,
) )
if self.session: if self.session:
self.session._apply_to(cmd, False, self.read_preference, sock_info) self.session._apply_to(cmd, False, self.read_preference, connection)
sock_info.add_server_api(cmd) connection.add_server_api(cmd)
sock_info.send_cluster_time(cmd, self.session, self.client) connection.send_cluster_time(cmd, self.session, self.client)
# Support auto encryption # Support auto encryption
client = self.client client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption: if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
# Support CSOT # Support CSOT
if apply_timeout: if apply_timeout:
sock_info.apply_timeout(client, cmd=None) connection.apply_timeout(client, cmd=None)
self._as_command = cmd, self.db self._as_command = cmd, self.db
return self._as_command return self._as_command
def get_message(self, dummy0, sock_info, use_cmd=False): def get_message(self, dummy0, connection, use_cmd=False):
"""Get a getmore message.""" """Get a getmore message."""
ns = self.namespace() ns = self.namespace()
ctx = sock_info.compression_context ctx = connection.compression_context
if use_cmd: if use_cmd:
spec = self.as_command(sock_info, apply_timeout=True)[0] spec = self.as_command(connection, apply_timeout=True)[0]
if self.sock_mgr: if self.sock_mgr:
flags = _OpMsg.EXHAUST_ALLOWED flags = _OpMsg.EXHAUST_ALLOWED
else: else:
flags = 0 flags = 0
request_id, msg, size, _ = _op_msg( request_id, msg, size, _ = _op_msg(
flags, spec, self.db, None, self.codec_options, ctx=sock_info.compression_context flags, spec, self.db, None, self.codec_options, ctx=connection.compression_context
) )
return request_id, msg, size return request_id, msg, size
@ -535,10 +535,10 @@ class _GetMore:
class _RawBatchQuery(_Query): class _RawBatchQuery(_Query):
def use_command(self, sock_info): def use_command(self, connection):
# Compatibility checks. # Compatibility checks.
super().use_command(sock_info) super().use_command(connection)
if sock_info.max_wire_version >= 8: if connection.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG # MongoDB 4.2+ supports exhaust over OP_MSG
return True return True
elif not self.exhaust: elif not self.exhaust:
@ -547,10 +547,10 @@ class _RawBatchQuery(_Query):
class _RawBatchGetMore(_GetMore): class _RawBatchGetMore(_GetMore):
def use_command(self, sock_info): def use_command(self, connection):
# Compatibility checks. # Compatibility checks.
super().use_command(sock_info) super().use_command(connection)
if sock_info.max_wire_version >= 8: if connection.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG # MongoDB 4.2+ supports exhaust over OP_MSG
return True return True
elif not self.exhaust: elif not self.exhaust:
@ -794,11 +794,11 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
class _BulkWriteContext: class _BulkWriteContext:
"""A wrapper around SocketInfo for use with write splitting functions.""" """A wrapper around Connection for use with write splitting functions."""
__slots__ = ( __slots__ = (
"db_name", "db_name",
"sock_info", "connection",
"op_id", "op_id",
"name", "name",
"field", "field",
@ -812,10 +812,10 @@ class _BulkWriteContext:
) )
def __init__( def __init__(
self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec self, database_name, cmd_name, connection, operation_id, listeners, session, op_type, codec
): ):
self.db_name = database_name self.db_name = database_name
self.sock_info = sock_info self.connection = connection
self.op_id = operation_id self.op_id = operation_id
self.listeners = listeners self.listeners = listeners
self.publish = listeners.enabled_for_commands self.publish = listeners.enabled_for_commands
@ -823,7 +823,7 @@ class _BulkWriteContext:
self.field = _FIELD_MAP[self.name] self.field = _FIELD_MAP[self.name]
self.start_time = datetime.datetime.now() if self.publish else None self.start_time = datetime.datetime.now() if self.publish else None
self.session = session self.session = session
self.compress = True if sock_info.compression_context else False self.compress = True if connection.compression_context else False
self.op_type = op_type self.op_type = op_type
self.codec = codec self.codec = codec
@ -855,20 +855,20 @@ class _BulkWriteContext:
@property @property
def max_bson_size(self): def max_bson_size(self):
"""A proxy for SockInfo.max_bson_size.""" """A proxy for SockInfo.max_bson_size."""
return self.sock_info.max_bson_size return self.connection.max_bson_size
@property @property
def max_message_size(self): def max_message_size(self):
"""A proxy for SockInfo.max_message_size.""" """A proxy for SockInfo.max_message_size."""
if self.compress: if self.compress:
# Subtract 16 bytes for the message header. # Subtract 16 bytes for the message header.
return self.sock_info.max_message_size - 16 return self.connection.max_message_size - 16
return self.sock_info.max_message_size return self.connection.max_message_size
@property @property
def max_write_batch_size(self): def max_write_batch_size(self):
"""A proxy for SockInfo.max_write_batch_size.""" """A proxy for SockInfo.max_write_batch_size."""
return self.sock_info.max_write_batch_size return self.connection.max_write_batch_size
@property @property
def max_split_size(self): def max_split_size(self):
@ -876,14 +876,14 @@ class _BulkWriteContext:
return self.max_bson_size return self.max_bson_size
def unack_write(self, cmd, request_id, msg, max_doc_size, docs): def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
"""A proxy for SocketInfo.unack_write that handles event publishing.""" """A proxy for Connection.unack_write that handles event publishing."""
if self.publish: if self.publish:
assert self.start_time is not None assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time duration = datetime.datetime.now() - self.start_time
cmd = self._start(cmd, request_id, docs) cmd = self._start(cmd, request_id, docs)
start = datetime.datetime.now() start = datetime.datetime.now()
try: try:
result = self.sock_info.unack_write(msg, max_doc_size) result = self.connection.unack_write(msg, max_doc_size)
if self.publish: if self.publish:
duration = (datetime.datetime.now() - start) + duration duration = (datetime.datetime.now() - start) + duration
if result is not None: if result is not None:
@ -910,14 +910,14 @@ class _BulkWriteContext:
@_handle_reauth @_handle_reauth
def write_command(self, cmd, request_id, msg, docs): def write_command(self, cmd, request_id, msg, docs):
"""A proxy for SocketInfo.write_command that handles event publishing.""" """A proxy for Connection.write_command that handles event publishing."""
if self.publish: if self.publish:
assert self.start_time is not None assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time duration = datetime.datetime.now() - self.start_time
self._start(cmd, request_id, docs) self._start(cmd, request_id, docs)
start = datetime.datetime.now() start = datetime.datetime.now()
try: try:
reply = self.sock_info.write_command(request_id, msg, self.codec) reply = self.connection.write_command(request_id, msg, self.codec)
if self.publish: if self.publish:
duration = (datetime.datetime.now() - start) + duration duration = (datetime.datetime.now() - start) + duration
self._succeed(request_id, reply, duration) self._succeed(request_id, reply, duration)
@ -941,9 +941,9 @@ class _BulkWriteContext:
cmd, cmd,
self.db_name, self.db_name,
request_id, request_id,
self.sock_info.address, self.connection.address,
self.op_id, self.op_id,
self.sock_info.service_id, self.connection.service_id,
) )
return cmd return cmd
@ -954,9 +954,9 @@ class _BulkWriteContext:
reply, reply,
self.name, self.name,
request_id, request_id,
self.sock_info.address, self.connection.address,
self.op_id, self.op_id,
self.sock_info.service_id, self.connection.service_id,
) )
def _fail(self, request_id, failure, duration): def _fail(self, request_id, failure, duration):
@ -966,9 +966,9 @@ class _BulkWriteContext:
failure, failure,
self.name, self.name,
request_id, request_id,
self.sock_info.address, self.connection.address,
self.op_id, self.op_id,
self.sock_info.service_id, self.connection.service_id,
) )
@ -997,14 +997,14 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
def execute(self, cmd, docs, client): def execute(self, cmd, docs, client):
batched_cmd, to_send = self._batch_command(cmd, docs) batched_cmd, to_send = self._batch_command(cmd, docs)
result = self.sock_info.command( result = self.connection.command(
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
) )
return result, to_send return result, to_send
def execute_unack(self, cmd, docs, client): def execute_unack(self, cmd, docs, client):
batched_cmd, to_send = self._batch_command(cmd, docs) batched_cmd, to_send = self._batch_command(cmd, docs)
self.sock_info.command( self.connection.command(
self.db_name, self.db_name,
batched_cmd, batched_cmd,
write_concern=WriteConcern(w=0), write_concern=WriteConcern(w=0),
@ -1124,7 +1124,7 @@ def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx):
""" """
data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
request_id, msg = _compress(2013, data, ctx.sock_info.compression_context) request_id, msg = _compress(2013, data, ctx.connection.compression_context)
return request_id, msg, to_send return request_id, msg, to_send
@ -1162,7 +1162,7 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
ack = bool(command["writeConcern"].get("w", 1)) ack = bool(command["writeConcern"].get("w", 1))
else: else:
ack = True ack = True
if ctx.sock_info.compression_context: if ctx.connection.compression_context:
return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx)
return _batched_op_msg(operation, command, docs, ack, opts, ctx) return _batched_op_msg(operation, command, docs, ack, opts, ctx)

View File

@ -1160,18 +1160,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _end_sessions(self, session_ids): def _end_sessions(self, session_ids):
"""Send endSessions command(s) with the given session ids.""" """Send endSessions command(s) with the given session ids."""
try: try:
# Use SocketInfo.command directly to avoid implicitly creating # Use Connection.command directly to avoid implicitly creating
# another session. # another session.
with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as ( with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as (
sock_info, connection,
read_pref, read_pref,
): ):
if not sock_info.supports_sessions: if not connection.supports_sessions:
return return
for i in range(0, len(session_ids), common._MAX_END_SESSIONS): for i in range(0, len(session_ids), common._MAX_END_SESSIONS):
spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])]) spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])])
sock_info.command("admin", spec, read_preference=read_pref, client=self) connection.command("admin", spec, read_preference=read_pref, client=self)
except PyMongoError: except PyMongoError:
# Drivers MUST ignore any errors returned by the endSessions # Drivers MUST ignore any errors returned by the endSessions
# command. # command.
@ -1224,23 +1224,23 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
err_handler.contribute_socket(session._pinned_connection) err_handler.contribute_socket(session._pinned_connection)
yield session._pinned_connection yield session._pinned_connection
return return
with server.get_socket(handler=err_handler) as sock_info: with server.get_socket(handler=err_handler) as connection:
# Pin this session to the selected server or connection. # Pin this session to the selected server or connection.
if in_txn and server.description.server_type in ( if in_txn and server.description.server_type in (
SERVER_TYPE.Mongos, SERVER_TYPE.Mongos,
SERVER_TYPE.LoadBalancer, SERVER_TYPE.LoadBalancer,
): ):
session._pin(server, sock_info) session._pin(server, connection)
err_handler.contribute_socket(sock_info) err_handler.contribute_socket(connection)
if ( if (
self._encrypter self._encrypter
and not self._encrypter._bypass_auto_encryption and not self._encrypter._bypass_auto_encryption
and sock_info.max_wire_version < 8 and connection.max_wire_version < 8
): ):
raise ConfigurationError( raise ConfigurationError(
"Auto-encryption requires a minimum MongoDB version of 4.2" "Auto-encryption requires a minimum MongoDB version of 4.2"
) )
yield sock_info yield connection
def _select_server(self, server_selector, session, address=None): def _select_server(self, server_selector, session, address=None):
"""Select a server to run an operation on this client. """Select a server to run an operation on this client.
@ -1281,7 +1281,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _socket_from_server(self, read_preference, server, session): def _socket_from_server(self, read_preference, server, session):
assert read_preference is not None, "read_preference must not be None" assert read_preference is not None, "read_preference must not be None"
# Get a socket for a server matching the read preference, and yield # Get a socket for a server matching the read preference, and yield
# sock_info with the effective read preference. The Server Selection # connection with the effective read preference. The Server Selection
# Spec says not to send any $readPreference to standalones and to # Spec says not to send any $readPreference to standalones and to
# always send primaryPreferred when directly connected to a repl set # always send primaryPreferred when directly connected to a repl set
# member. # member.
@ -1289,16 +1289,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
topology = self._get_topology() topology = self._get_topology()
single = topology.description.topology_type == TOPOLOGY_TYPE.Single single = topology.description.topology_type == TOPOLOGY_TYPE.Single
with self._get_socket(server, session) as sock_info: with self._get_socket(server, session) as connection:
if single: if single:
if sock_info.is_repl and not (session and session.in_transaction): if connection.is_repl and not (session and session.in_transaction):
# Use primary preferred to ensure any repl set member # Use primary preferred to ensure any repl set member
# can handle the request. # can handle the request.
read_preference = ReadPreference.PRIMARY_PREFERRED read_preference = ReadPreference.PRIMARY_PREFERRED
elif sock_info.is_standalone: elif connection.is_standalone:
# Don't send read preference to standalones. # Don't send read preference to standalones.
read_preference = ReadPreference.PRIMARY read_preference = ReadPreference.PRIMARY
yield sock_info, read_preference yield connection, read_preference
def _socket_for_reads(self, read_preference, session): def _socket_for_reads(self, read_preference, session):
assert read_preference is not None, "read_preference must not be None" assert read_preference is not None, "read_preference must not be None"
@ -1331,10 +1331,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res
) )
def _cmd(session, server, sock_info, read_preference): def _cmd(session, server, connection, read_preference):
operation.reset() # Reset op in case of retry. operation.reset() # Reset op in case of retry.
return server.run_operation( return server.run_operation(
sock_info, operation, read_preference, self._event_listeners, unpack_res connection, operation, read_preference, self._event_listeners, unpack_res
) )
return self._retryable_read( return self._retryable_read(
@ -1388,8 +1388,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
supports_session = ( supports_session = (
session is not None and server.description.retryable_writes_supported session is not None and server.description.retryable_writes_supported
) )
with self._get_socket(server, session) as sock_info: with self._get_socket(server, session) as connection:
max_wire_version = sock_info.max_wire_version max_wire_version = connection.max_wire_version
if retryable and not supports_session: if retryable and not supports_session:
if is_retrying(): if is_retrying():
# A retry is not possible because this server does # A retry is not possible because this server does
@ -1397,7 +1397,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
assert last_error is not None assert last_error is not None
raise last_error raise last_error
retryable = False retryable = False
return func(session, sock_info, retryable) return func(session, connection, retryable)
except ServerSelectionTimeoutError: except ServerSelectionTimeoutError:
if is_retrying(): if is_retrying():
# The application may think the write was never attempted # The application may think the write was never attempted
@ -1455,13 +1455,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
raise last_error raise last_error
try: try:
server = self._select_server(read_pref, session, address=address) server = self._select_server(read_pref, session, address=address)
with self._socket_from_server(read_pref, server, session) as (sock_info, read_pref): with self._socket_from_server(read_pref, server, session) as (
connection,
read_pref,
):
if retrying and not retryable: if retrying and not retryable:
# A retry is not possible because this server does # A retry is not possible because this server does
# not support retryable reads, raise the last error. # not support retryable reads, raise the last error.
assert last_error is not None assert last_error is not None
raise last_error raise last_error
return func(session, server, sock_info, read_pref) return func(session, server, connection, read_pref)
except ServerSelectionTimeoutError: except ServerSelectionTimeoutError:
if retrying: if retrying:
# The application may think the write was never attempted # The application may think the write was never attempted
@ -1633,14 +1636,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# Application called close_cursor() with no address. # Application called close_cursor() with no address.
server = topology.select_server(writable_server_selector) server = topology.select_server(writable_server_selector)
with self._get_socket(server, session) as sock_info: with self._get_socket(server, session) as connection:
self._kill_cursor_impl(cursor_ids, address, session, sock_info) self._kill_cursor_impl(cursor_ids, address, session, connection)
def _kill_cursor_impl(self, cursor_ids, address, session, sock_info): def _kill_cursor_impl(self, cursor_ids, address, session, connection):
namespace = address.namespace namespace = address.namespace
db, coll = namespace.split(".", 1) db, coll = namespace.split(".", 1)
spec = SON([("killCursors", coll), ("cursors", cursor_ids)]) spec = SON([("killCursors", coll), ("cursors", cursor_ids)])
sock_info.command(db, spec, session=session, client=self) connection.command(db, spec, session=session, client=self)
def _process_kill_cursors(self): def _process_kill_cursors(self):
"""Process any pending kill cursors requests.""" """Process any pending kill cursors requests."""
@ -1925,9 +1928,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if not isinstance(name, str): if not isinstance(name, str):
raise TypeError("name_or_database must be an instance of str or a Database") raise TypeError("name_or_database must be an instance of str or a Database")
with self._socket_for_writes(session) as sock_info: with self._socket_for_writes(session) as connection:
self[name]._command( self[name]._command(
sock_info, connection,
{"dropDatabase": 1, "comment": comment}, {"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),
@ -2149,11 +2152,11 @@ class _MongoClientErrorHandler:
self.service_id = None self.service_id = None
self.handled = False self.handled = False
def contribute_socket(self, sock_info, completed_handshake=True): def contribute_socket(self, connection, completed_handshake=True):
"""Provide socket information to the error handler.""" """Provide socket information to the error handler."""
self.max_wire_version = sock_info.max_wire_version self.max_wire_version = connection.max_wire_version
self.sock_generation = sock_info.generation self.sock_generation = connection.generation
self.service_id = sock_info.service_id self.service_id = connection.service_id
self.completed_handshake = completed_handshake self.completed_handshake = completed_handshake
def handle(self, exc_type, exc_val): def handle(self, exc_type, exc_val):

View File

@ -245,9 +245,9 @@ class Monitor(MonitorBase):
if self._cancel_context and self._cancel_context.cancelled: if self._cancel_context and self._cancel_context.cancelled:
self._reset_connection() self._reset_connection()
with self._pool.get_socket() as sock_info: with self._pool.get_socket() as connection:
self._cancel_context = sock_info.cancel_context self._cancel_context = connection.cancel_context
response, round_trip_time = self._check_with_socket(sock_info) response, round_trip_time = self._check_with_socket(connection)
if not response.awaitable: if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time) self._rtt_monitor.add_sample(round_trip_time)
@ -393,11 +393,11 @@ class _RttMonitor(MonitorBase):
def _ping(self): def _ping(self):
"""Run a "hello" command and return the RTT.""" """Run a "hello" command and return the RTT."""
with self._pool.get_socket() as sock_info: with self._pool.get_socket() as connection:
if self._executor._stopped: if self._executor._stopped:
raise Exception("_RttMonitor closed") raise Exception("_RttMonitor closed")
start = time.monotonic() start = time.monotonic()
sock_info.hello() connection.hello()
return time.monotonic() - start return time.monotonic() - start

View File

@ -52,7 +52,7 @@ if TYPE_CHECKING:
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.mongo_client import MongoClient from pymongo.mongo_client import MongoClient
from pymongo.monitoring import _EventListeners from pymongo.monitoring import _EventListeners
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address from pymongo.typings import _Address
@ -62,7 +62,7 @@ _UNPACK_HEADER = struct.Struct("<iiii").unpack
def command( def command(
sock_info: SocketInfo, connection: Connection,
dbname: str, dbname: str,
spec: MutableMapping[str, Any], spec: MutableMapping[str, Any],
is_mongos: bool, is_mongos: bool,
@ -125,7 +125,7 @@ def command(
if read_concern.level: if read_concern.level:
spec["readConcern"] = read_concern.document spec["readConcern"] = read_concern.document
if session: if session:
session._update_read_concern(spec, sock_info) session._update_read_concern(spec, connection)
if collation is not None: if collation is not None:
spec["collation"] = collation spec["collation"] = collation
@ -142,7 +142,7 @@ def command(
# Support CSOT # Support CSOT
if client: if client:
sock_info.apply_timeout(client, spec) connection.apply_timeout(client, spec)
_csot.apply_write_concern(spec, write_concern) _csot.apply_write_concern(spec, write_concern)
if use_op_msg: if use_op_msg:
@ -167,19 +167,19 @@ def command(
encoding_duration = datetime.datetime.now() - start encoding_duration = datetime.datetime.now() - start
assert listeners is not None assert listeners is not None
listeners.publish_command_start( listeners.publish_command_start(
orig, dbname, request_id, address, service_id=sock_info.service_id orig, dbname, request_id, address, service_id=connection.service_id
) )
start = datetime.datetime.now() start = datetime.datetime.now()
try: try:
sock_info.sock.sendall(msg) connection.send_message(msg, max_bson_size)
if use_op_msg and unacknowledged: if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response. # Unacknowledged, fake a successful command response.
reply = None reply = None
response_doc = {"ok": 1} response_doc = {"ok": 1}
else: else:
reply = receive_message(sock_info, request_id) reply = connection.receive_message(request_id)
sock_info.more_to_come = reply.more_to_come connection.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response( unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields codec_options=codec_options, user_fields=user_fields
) )
@ -190,7 +190,7 @@ def command(
if check: if check:
helpers._check_command_response( helpers._check_command_response(
response_doc, response_doc,
sock_info.max_wire_version, connection.max_wire_version,
allowable_errors, allowable_errors,
parse_write_concern_error=parse_write_concern_error, parse_write_concern_error=parse_write_concern_error,
) )
@ -203,7 +203,7 @@ def command(
failure = message._convert_exception(exc) failure = message._convert_exception(exc)
assert listeners is not None assert listeners is not None
listeners.publish_command_failure( listeners.publish_command_failure(
duration, failure, name, request_id, address, service_id=sock_info.service_id duration, failure, name, request_id, address, service_id=connection.service_id
) )
raise raise
if publish: if publish:
@ -215,7 +215,7 @@ def command(
name, name,
request_id, request_id,
address, address,
service_id=sock_info.service_id, service_id=connection.service_id,
speculative_hello=speculative_hello, speculative_hello=speculative_hello,
) )
@ -229,21 +229,21 @@ def command(
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
def receive_message( def receive_message_tcp(
sock_info: SocketInfo, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE connection: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]: ) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error.""" """Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout(): if _csot.get_timeout():
deadline = _csot.get_deadline() deadline = _csot.get_deadline()
else: else:
timeout = sock_info.sock.gettimeout() timeout = connection.connector.gettimeout()
if timeout: if timeout:
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
else: else:
deadline = None deadline = None
# Ignore the response's request id. # Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER( length, _, response_to, op_code = _UNPACK_HEADER(
_receive_data_on_socket(sock_info, 16, deadline) _receive_data_on_socket(connection, 16, deadline)
) )
# No request_id for exhaust cursor "getMore". # No request_id for exhaust cursor "getMore".
if request_id is not None: if request_id is not None:
@ -260,11 +260,11 @@ def receive_message(
) )
if op_code == 2012: if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(sock_info, 9, deadline) _receive_data_on_socket(connection, 9, deadline)
) )
data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id) data = decompress(_receive_data_on_socket(connection, length - 25, deadline), compressor_id)
else: else:
data = _receive_data_on_socket(sock_info, length - 16, deadline) data = _receive_data_on_socket(connection, length - 16, deadline)
try: try:
unpack_reply = _UNPACK_REPLY[op_code] unpack_reply = _UNPACK_REPLY[op_code]
@ -273,15 +273,51 @@ def receive_message(
return unpack_reply(data) return unpack_reply(data)
def receive_message_grpc(
connection: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
# Ignore the response's request id.
if connection.response is not None:
data = connection.response.__next__()
length, _, response_to, op_code = _UNPACK_HEADER(data[:16])
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
"Message length ({!r}) is larger than server max "
"message size ({!r})".format(length, max_message_size)
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(data[:9])
data = decompress(data[25:], compressor_id)
else:
data = data[16:]
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}")
return unpack_reply(data)
else:
raise ProtocolError(f"Received empty response for request id {request_id}")
_POLL_TIMEOUT = 0.5 _POLL_TIMEOUT = 0.5
def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None: def wait_for_read(connection: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel.""" """Block until at least one byte is read, or a timeout, or a cancel."""
context = sock_info.cancel_context context = connection.cancel_context
# Only Monitor connections can be cancelled. # Only Monitor connections can be cancelled.
if context: if context:
sock = sock_info.sock sock = connection.connector
timed_out = False timed_out = False
while True: while True:
# SSLSocket can have buffered data which won't be caught by select. # SSLSocket can have buffered data which won't be caught by select.
@ -300,7 +336,7 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
timeout = max(min(remaining, _POLL_TIMEOUT), 0) timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else: else:
timeout = _POLL_TIMEOUT timeout = _POLL_TIMEOUT
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout) readable = connection.socket_checker.select(sock, read=True, timeout=timeout)
if context.cancelled: if context.cancelled:
raise _OperationCancelled("hello cancelled") raise _OperationCancelled("hello cancelled")
if readable: if readable:
@ -314,20 +350,20 @@ BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
def _receive_data_on_socket( def _receive_data_on_socket(
sock_info: SocketInfo, length: int, deadline: Optional[float] connection: Connection, length: int, deadline: Optional[float]
) -> memoryview: ) -> memoryview:
buf = bytearray(length) buf = bytearray(length)
mv = memoryview(buf) mv = memoryview(buf)
bytes_read = 0 bytes_read = 0
while bytes_read < length: while bytes_read < length:
try: try:
wait_for_read(sock_info, deadline) wait_for_read(connection, deadline)
# CSOT: Update timeout. When the timeout has expired perform one # CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when # final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client. # the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None: if _csot.get_timeout() and deadline is not None:
sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0)) connection.set_connector_timeout(max(deadline - time.monotonic(), 0))
chunk_length = sock_info.sock.recv_into(mv[bytes_read:]) chunk_length = connection.connector.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS: except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") raise socket.timeout("timed out")
except OSError as exc: # noqa: B014 except OSError as exc: # noqa: B014

View File

@ -23,7 +23,14 @@ import sys
import threading import threading
import time import time
import weakref import weakref
from typing import Any, Dict, NoReturn, Optional from typing import Any, Dict, Iterator, NoReturn, Optional
try:
import grpc # type: ignore
_HAVE_GRPC = True
except ImportError:
_HAVE_GRPC = False
import bson import bson
from bson import DEFAULT_CODEC_OPTIONS from bson import DEFAULT_CODEC_OPTIONS
@ -60,7 +67,7 @@ from pymongo.hello import Hello, HelloCompat
from pymongo.helpers import _handle_reauth from pymongo.helpers import _handle_reauth
from pymongo.lock import _create_lock from pymongo.lock import _create_lock
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
from pymongo.network import command, receive_message from pymongo.network import command, receive_message_grpc, receive_message_tcp
from pymongo.read_preferences import ReadPreference from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE from pymongo.server_type import SERVER_TYPE
@ -370,6 +377,12 @@ def _cond_wait(condition, deadline):
return condition.wait(timeout) return condition.wait(timeout)
class ConnectionProtocol:
TCP_SOCKET = 0
GRPC = 1
class PoolOptions: class PoolOptions:
"""Read only connection pool options for a MongoClient. """Read only connection pool options for a MongoClient.
@ -402,6 +415,7 @@ class PoolOptions:
"__server_api", "__server_api",
"__load_balanced", "__load_balanced",
"__credentials", "__credentials",
"__protocol",
) )
def __init__( def __init__(
@ -423,6 +437,7 @@ class PoolOptions:
server_api=None, server_api=None,
load_balanced=None, load_balanced=None,
credentials=None, credentials=None,
protocol=ConnectionProtocol.TCP_SOCKET,
): ):
self.__max_pool_size = max_pool_size self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size self.__min_pool_size = min_pool_size
@ -441,6 +456,7 @@ class PoolOptions:
self.__server_api = server_api self.__server_api = server_api
self.__load_balanced = load_balanced self.__load_balanced = load_balanced
self.__credentials = credentials self.__credentials = credentials
self.__protocol = protocol
self.__metadata = copy.deepcopy(_METADATA) self.__metadata = copy.deepcopy(_METADATA)
if appname: if appname:
self.__metadata["application"] = {"name": appname} self.__metadata["application"] = {"name": appname}
@ -473,6 +489,11 @@ class PoolOptions:
_truncate_metadata(self.__metadata) _truncate_metadata(self.__metadata)
@property
def protocol(self):
"""A :class:`~pymongo.pool.ConnectionProtocol` value."""
return self.__protocol
@property @property
def _credentials(self): def _credentials(self):
"""A :class:`~pymongo.auth.MongoCredentials` instance or None.""" """A :class:`~pymongo.auth.MongoCredentials` instance or None."""
@ -614,21 +635,22 @@ class _CancellationContext:
return self._cancelled return self._cancelled
class SocketInfo: class Connection:
"""Store a socket with some metadata. """Store a connection with some metadata.
:Parameters: :Parameters:
- `sock`: a raw socket object - `connector`: a raw connector implementation, currently either a TCP Socket or a gRPC Stream
- `pool`: a Pool instance - `pool`: a Pool instance
- `address`: the server's (host, port) - `address`: the server's (host, port)
- `id`: the id of this socket in it's pool - `id`: the id of this socket in its pool
""" """
def __init__(self, sock, pool, address, id): def __init__(self, connector, pool, address, id, protocol):
self.pool_ref = weakref.ref(pool) self.pool_ref = weakref.ref(pool)
self.sock = sock self.connector = connector
self.address = address self.address = address
self.id = id self.id = id
self.protocol = protocol
self.authed = set() self.authed = set()
self.closed = False self.closed = False
self.last_checkin_time = time.monotonic() self.last_checkin_time = time.monotonic()
@ -671,14 +693,18 @@ class SocketInfo:
self.pinned_cursor = False self.pinned_cursor = False
self.active = False self.active = False
self.last_timeout = self.opts.socket_timeout self.last_timeout = self.opts.socket_timeout
self.current_timeout = self.last_timeout
self.connect_rtt = 0.0 self.connect_rtt = 0.0
self.response: Optional[Iterator] = None
def set_socket_timeout(self, timeout): def set_connector_timeout(self, timeout):
"""Cache last timeout to avoid duplicate calls to sock.settimeout.""" """Cache last timeout to avoid duplicate calls to connector timeout implementations."""
if timeout == self.last_timeout: if timeout == self.last_timeout:
return return
self.last_timeout = timeout self.last_timeout = timeout
self.sock.settimeout(timeout) self.current_timeout = timeout
if self.protocol == ConnectionProtocol.TCP_SOCKET: # TODO: Implement timeouts for gRPC mode
self.connector.settimeout(timeout)
def apply_timeout(self, client, cmd): def apply_timeout(self, client, cmd):
# CSOT: use remaining timeout when set. # CSOT: use remaining timeout when set.
@ -686,7 +712,7 @@ class SocketInfo:
if timeout is None: if timeout is None:
# Reset the socket timeout unless we're performing a streaming monitor check. # Reset the socket timeout unless we're performing a streaming monitor check.
if not self.more_to_come: if not self.more_to_come:
self.set_socket_timeout(self.opts.socket_timeout) self.set_connector_timeout(self.opts.socket_timeout)
return None return None
# RTT validation. # RTT validation.
rtt = _csot.get_rtt() rtt = _csot.get_rtt()
@ -701,7 +727,7 @@ class SocketInfo:
) )
if cmd is not None: if cmd is not None:
cmd["maxTimeMS"] = int(max_time_ms * 1000) cmd["maxTimeMS"] = int(max_time_ms * 1000)
self.set_socket_timeout(timeout) self.set_connector_timeout(timeout)
return timeout return timeout
def pin_txn(self): def pin_txn(self):
@ -748,7 +774,7 @@ class SocketInfo:
awaitable = True awaitable = True
# If connect_timeout is None there is no timeout. # If connect_timeout is None there is no timeout.
if self.opts.connect_timeout: if self.opts.connect_timeout:
self.set_socket_timeout(self.opts.connect_timeout + heartbeat_frequency) self.set_connector_timeout(self.opts.connect_timeout + heartbeat_frequency)
if not performing_handshake and cluster_time is not None: if not performing_handshake and cluster_time is not None:
cmd["$clusterTime"] = cluster_time cmd["$clusterTime"] = cluster_time
@ -910,7 +936,7 @@ class SocketInfo:
def send_message(self, message, max_doc_size): def send_message(self, message, max_doc_size):
"""Send a raw BSON message or raise ConnectionFailure. """Send a raw BSON message or raise ConnectionFailure.
If a network exception is raised, the socket is closed. If a network exception is raised, the connector is closed.
""" """
if self.max_bson_size is not None and max_doc_size > self.max_bson_size: if self.max_bson_size is not None and max_doc_size > self.max_bson_size:
raise DocumentTooLarge( raise DocumentTooLarge(
@ -919,17 +945,32 @@ class SocketInfo:
) )
try: try:
self.sock.sendall(message) if self.protocol == ConnectionProtocol.GRPC:
self.response = self.connector.__call__(
iter([message]),
metadata=[
("security-uuid", "uuid"),
("username", "user"),
("servername", "host.local.10gen.cc"),
("mongodb-wireversion", "18"),
("x-forwarded-for", "127.0.0.1:9901"),
],
)
else:
self.connector.sendall(message)
except BaseException as error: except BaseException as error:
self._raise_connection_failure(error) self._raise_connection_failure(error)
def receive_message(self, request_id): def receive_message(self, request_id):
"""Receive a raw BSON message or raise ConnectionFailure. """Receive a raw BSON message or raise ConnectionFailure.
If any exception is raised, the socket is closed. If any exception is raised, the connector is closed.
""" """
try: try:
return receive_message(self, request_id, self.max_message_size) if self.protocol == ConnectionProtocol.GRPC:
return receive_message_grpc(self, request_id, self.max_message_size)
else:
return receive_message_tcp(self, request_id, self.max_message_size)
except BaseException as error: except BaseException as error:
self._raise_connection_failure(error) self._raise_connection_failure(error)
@ -1017,13 +1058,13 @@ class SocketInfo:
# Note: We catch exceptions to avoid spurious errors on interpreter # Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown. # shutdown.
try: try:
self.sock.close() self.connector.close()
except Exception: except Exception:
pass pass
def socket_closed(self): def socket_closed(self):
"""Return True if we know socket has been closed, False otherwise.""" """Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.sock) return self.socket_checker.socket_closed(self.connector)
def send_cluster_time(self, command, session, client): def send_cluster_time(self, command, session, client):
"""Add $clusterTime.""" """Add $clusterTime."""
@ -1073,17 +1114,17 @@ class SocketInfo:
raise raise
def __eq__(self, other): def __eq__(self, other):
return self.sock == other.sock return self.connector == other.connector
def __ne__(self, other): def __ne__(self, other):
return not self == other return not self == other
def __hash__(self): def __hash__(self):
return hash(self.sock) return hash(self.connector)
def __repr__(self): def __repr__(self):
return "SocketInfo({}){} at {}".format( return "Connection({}){} at {}".format(
repr(self.sock), repr(self.connector),
self.closed and " CLOSED" or "", self.closed and " CLOSED" or "",
id(self), id(self),
) )
@ -1256,7 +1297,7 @@ class Pool:
:Parameters: :Parameters:
- `address`: a (hostname, port) tuple - `address`: a (hostname, port) tuple
- `options`: a PoolOptions instance - `options`: a PoolOptions instance
- `handshake`: whether to call hello for each new SocketInfo - `handshake`: whether to call hello for each new Connection
""" """
if options.pause_enabled: if options.pause_enabled:
self.state = PoolState.PAUSED self.state = PoolState.PAUSED
@ -1318,6 +1359,11 @@ class Pool:
self.ncursors = 0 self.ncursors = 0
self.ntxns = 0 self.ntxns = 0
if self.opts.protocol == ConnectionProtocol.GRPC:
self.channel = self._create_grpc_channel()
else:
self.channel = None
def ready(self): def ready(self):
# Take the lock to avoid the race condition described in PYTHON-2699. # Take the lock to avoid the race condition described in PYTHON-2699.
with self.lock: with self.lock:
@ -1348,11 +1394,11 @@ class Pool:
else: else:
discard: collections.deque = collections.deque() discard: collections.deque = collections.deque()
keep: collections.deque = collections.deque() keep: collections.deque = collections.deque()
for sock_info in self.sockets: for connection in self.sockets:
if sock_info.service_id == service_id: if connection.service_id == service_id:
discard.append(sock_info) discard.append(connection)
else: else:
keep.append(sock_info) keep.append(connection)
sockets = discard sockets = discard
self.sockets = keep self.sockets = keep
@ -1367,15 +1413,15 @@ class Pool:
# PoolClosedEvent but that reset() SHOULD close sockets *after* # PoolClosedEvent but that reset() SHOULD close sockets *after*
# publishing the PoolClearedEvent. # publishing the PoolClearedEvent.
if close: if close:
for sock_info in sockets: for connection in sockets:
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) connection.close_socket(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap: if self.enabled_for_cmap:
listeners.publish_pool_closed(self.address) listeners.publish_pool_closed(self.address)
else: else:
if old_state != PoolState.PAUSED and self.enabled_for_cmap: if old_state != PoolState.PAUSED and self.enabled_for_cmap:
listeners.publish_pool_cleared(self.address, service_id=service_id) listeners.publish_pool_cleared(self.address, service_id=service_id)
for sock_info in sockets: for connection in sockets:
sock_info.close_socket(ConnectionClosedReason.STALE) connection.close_socket(ConnectionClosedReason.STALE)
def update_is_writable(self, is_writable): def update_is_writable(self, is_writable):
"""Updates the is_writable attribute on all sockets currently in the """Updates the is_writable attribute on all sockets currently in the
@ -1395,6 +1441,10 @@ class Pool:
def close(self): def close(self):
self._reset(close=True) self._reset(close=True)
def _create_grpc_channel(self):
connection_string = self.address[0] + ":" + str(self.address[1])
return grpc.insecure_channel(connection_string, options=self.grpc_channel_options())
def stale_generation(self, gen, service_id): def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id) return self.gen.stale(gen, service_id)
@ -1415,8 +1465,8 @@ class Pool:
self.sockets self.sockets
and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
): ):
sock_info = self.sockets.pop() connection = self.sockets.pop()
sock_info.close_socket(ConnectionClosedReason.IDLE) connection.close_socket(ConnectionClosedReason.IDLE)
while True: while True:
with self.size_cond: with self.size_cond:
@ -1435,14 +1485,14 @@ class Pool:
return return
self._pending += 1 self._pending += 1
incremented = True incremented = True
sock_info = self.connect() connection = self.connect()
with self.lock: with self.lock:
# Close connection and return if the pool was reset during # Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock. # socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation: if self.gen.get_overall() != reference_generation:
sock_info.close_socket(ConnectionClosedReason.STALE) connection.close_socket(ConnectionClosedReason.STALE)
return return
self.sockets.appendleft(sock_info) self.sockets.appendleft(connection)
finally: finally:
if incremented: if incremented:
# Notify after adding the socket to the pool. # Notify after adding the socket to the pool.
@ -1455,7 +1505,7 @@ class Pool:
self.size_cond.notify() self.size_cond.notify()
def connect(self, handler=None): def connect(self, handler=None):
"""Connect to Mongo and return a new SocketInfo. """Connect to Mongo and return a new Connection.
Can raise ConnectionFailure. Can raise ConnectionFailure.
@ -1471,7 +1521,12 @@ class Pool:
listeners.publish_connection_created(self.address, conn_id) listeners.publish_connection_created(self.address, conn_id)
try: try:
sock = _configured_socket(self.address, self.opts) if self.opts.protocol == ConnectionProtocol.GRPC:
connector = self.channel.stream_stream(
"/mongodb.CommandService/UnauthenticatedCommandStream"
)
else:
connector = _configured_socket(self.address, self.opts)
except BaseException as error: except BaseException as error:
if self.enabled_for_cmap: if self.enabled_for_cmap:
listeners.publish_connection_closed( listeners.publish_connection_closed(
@ -1483,26 +1538,42 @@ class Pool:
raise raise
sock_info = SocketInfo(sock, self, self.address, conn_id) connection = Connection(connector, self, self.address, conn_id, self.opts.protocol)
try: try:
if self.handshake: if self.handshake:
sock_info.hello() connection.hello()
self.is_writable = sock_info.is_writable self.is_writable = connection.is_writable
if handler: if handler:
handler.contribute_socket(sock_info, completed_handshake=False) handler.contribute_socket(connection, completed_handshake=False)
sock_info.authenticate() connection.authenticate()
except BaseException: except BaseException:
sock_info.close_socket(ConnectionClosedReason.ERROR) connection.close_socket(ConnectionClosedReason.ERROR)
raise raise
return sock_info return connection
def grpc_metadata(self):
return [
("security-uuid", "client-uuid"),
("username", "user"),
("servername", "host.local.10gen.cc"),
("mongodb-wireversion", 18),
("x-forwarded-for", "127.0.0.1:9901"),
]
def grpc_channel_options(self):
return [
("grpc.default_authority", "host.local.10gen.cc"),
("grpc.max_receive_message_length", 48000000),
("grpc.max_send_message_length", 48000000),
]
@contextlib.contextmanager @contextlib.contextmanager
def get_socket(self, handler=None): def get_socket(self, handler=None):
"""Get a socket from the pool. Use with a "with" statement. """Get a socket from the pool. Use with a "with" statement.
Returns a :class:`SocketInfo` object wrapping a connected Returns a :class:`Connection` object wrapping a connected
:class:`socket.socket`. :class:`socket.socket`.
This method should always be used in a with-statement:: This method should always be used in a with-statement::
@ -1520,36 +1591,36 @@ class Pool:
if self.enabled_for_cmap: if self.enabled_for_cmap:
listeners.publish_connection_check_out_started(self.address) listeners.publish_connection_check_out_started(self.address)
sock_info = self._get_socket(handler=handler) connection = self._get_socket(handler=handler)
if self.enabled_for_cmap: if self.enabled_for_cmap:
listeners.publish_connection_checked_out(self.address, sock_info.id) listeners.publish_connection_checked_out(self.address, connection.id)
try: try:
yield sock_info yield connection
except BaseException: except BaseException:
# Exception in caller. Ensure the connection gets returned. # Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the # Note that when pinned is True, the session owns the
# connection and it is responsible for checking the connection # connection and it is responsible for checking the connection
# back into the pool. # back into the pool.
pinned = sock_info.pinned_txn or sock_info.pinned_cursor pinned = connection.pinned_txn or connection.pinned_cursor
if handler: if handler:
# Perform SDAM error handling rules while the connection is # Perform SDAM error handling rules while the connection is
# still checked out. # still checked out.
exc_type, exc_val, _ = sys.exc_info() exc_type, exc_val, _ = sys.exc_info()
handler.handle(exc_type, exc_val) handler.handle(exc_type, exc_val)
if not pinned and sock_info.active: if not pinned and connection.active:
self.return_socket(sock_info) self.return_socket(connection)
raise raise
if sock_info.pinned_txn: if connection.pinned_txn:
with self.lock: with self.lock:
self.__pinned_sockets.add(sock_info) self.__pinned_sockets.add(connection)
self.ntxns += 1 self.ntxns += 1
elif sock_info.pinned_cursor: elif connection.pinned_cursor:
with self.lock: with self.lock:
self.__pinned_sockets.add(sock_info) self.__pinned_sockets.add(connection)
self.ncursors += 1 self.ncursors += 1
elif sock_info.active: elif connection.active:
self.return_socket(sock_info) self.return_socket(connection)
def _raise_if_not_ready(self, emit_event): def _raise_if_not_ready(self, emit_event):
if self.state != PoolState.READY: if self.state != PoolState.READY:
@ -1560,7 +1631,7 @@ class Pool:
_raise_connection_failure(self.address, AutoReconnect("connection pool paused")) _raise_connection_failure(self.address, AutoReconnect("connection pool paused"))
def _get_socket(self, handler=None): def _get_socket(self, handler=None):
"""Get or create a SocketInfo. Can raise ConnectionFailure.""" """Get or create a Connection. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing. # We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of # See test.test_client:TestClient.test_fork for an example of
# what could go wrong otherwise # what could go wrong otherwise
@ -1600,7 +1671,7 @@ class Pool:
self.requests += 1 self.requests += 1
# We've now acquired the semaphore and must release it on error. # We've now acquired the semaphore and must release it on error.
sock_info = None connection = None
incremented = False incremented = False
emitted_event = False emitted_event = False
try: try:
@ -1608,7 +1679,7 @@ class Pool:
self.active_sockets += 1 self.active_sockets += 1
incremented = True incremented = True
while sock_info is None: while connection is None:
# CMAP: we MUST wait for either maxConnecting OR for a socket # CMAP: we MUST wait for either maxConnecting OR for a socket
# to be checked back into the pool. # to be checked back into the pool.
with self._max_connecting_cond: with self._max_connecting_cond:
@ -1624,24 +1695,24 @@ class Pool:
self._raise_if_not_ready(emit_event=False) self._raise_if_not_ready(emit_event=False)
try: try:
sock_info = self.sockets.popleft() connection = self.sockets.popleft()
except IndexError: except IndexError:
self._pending += 1 self._pending += 1
if sock_info: # We got a socket from the pool if connection: # We got a socket from the pool
if self._perished(sock_info): if self._perished(connection):
sock_info = None connection = None
continue continue
else: # We need to create a new connection else: # We need to create a new connection
try: try:
sock_info = self.connect(handler=handler) connection = self.connect(handler=handler)
finally: finally:
with self._max_connecting_cond: with self._max_connecting_cond:
self._pending -= 1 self._pending -= 1
self._max_connecting_cond.notify() self._max_connecting_cond.notify()
except BaseException: except BaseException:
if sock_info: if connection:
# We checked out a socket but authentication failed. # We checked out a socket but authentication failed.
sock_info.close_socket(ConnectionClosedReason.ERROR) connection.close_socket(ConnectionClosedReason.ERROR)
with self.size_cond: with self.size_cond:
self.requests -= 1 self.requests -= 1
if incremented: if incremented:
@ -1654,45 +1725,45 @@ class Pool:
) )
raise raise
sock_info.active = True connection.active = True
return sock_info return connection
def return_socket(self, sock_info): def return_socket(self, connection):
"""Return the socket to the pool, or if it's closed discard it. """Return the socket to the pool, or if it's closed discard it.
:Parameters: :Parameters:
- `sock_info`: The socket to check into the pool. - `connection`: The socket to check into the pool.
""" """
txn = sock_info.pinned_txn txn = connection.pinned_txn
cursor = sock_info.pinned_cursor cursor = connection.pinned_cursor
sock_info.active = False connection.active = False
sock_info.pinned_txn = False connection.pinned_txn = False
sock_info.pinned_cursor = False connection.pinned_cursor = False
self.__pinned_sockets.discard(sock_info) self.__pinned_sockets.discard(connection)
listeners = self.opts._event_listeners listeners = self.opts._event_listeners
if self.enabled_for_cmap: if self.enabled_for_cmap:
listeners.publish_connection_checked_in(self.address, sock_info.id) listeners.publish_connection_checked_in(self.address, connection.id)
if self.pid != os.getpid(): if self.pid != os.getpid():
self.reset_without_pause() self.reset_without_pause()
else: else:
if self.closed: if self.closed:
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) connection.close_socket(ConnectionClosedReason.POOL_CLOSED)
elif sock_info.closed: elif connection.closed:
# CMAP requires the closed event be emitted after the check in. # CMAP requires the closed event be emitted after the check in.
if self.enabled_for_cmap: if self.enabled_for_cmap:
listeners.publish_connection_closed( listeners.publish_connection_closed(
self.address, sock_info.id, ConnectionClosedReason.ERROR self.address, connection.id, ConnectionClosedReason.ERROR
) )
else: else:
with self.lock: with self.lock:
# Hold the lock to ensure this section does not race with # Hold the lock to ensure this section does not race with
# Pool.reset(). # Pool.reset().
if self.stale_generation(sock_info.generation, sock_info.service_id): if self.stale_generation(connection.generation, connection.service_id):
sock_info.close_socket(ConnectionClosedReason.STALE) connection.close_socket(ConnectionClosedReason.STALE)
else: else:
sock_info.update_last_checkin_time() connection.update_last_checkin_time()
sock_info.update_is_writable(self.is_writable) connection.update_is_writable(self.is_writable)
self.sockets.appendleft(sock_info) self.sockets.appendleft(connection)
# Notify any threads waiting to create a connection. # Notify any threads waiting to create a connection.
self._max_connecting_cond.notify() self._max_connecting_cond.notify()
@ -1706,7 +1777,7 @@ class Pool:
self.operation_count -= 1 self.operation_count -= 1
self.size_cond.notify() self.size_cond.notify()
def _perished(self, sock_info): def _perished(self, connection):
"""Return True and close the connection if it is "perished". """Return True and close the connection if it is "perished".
This side-effecty function checks if this socket has been idle for This side-effecty function checks if this socket has been idle for
@ -1720,24 +1791,24 @@ class Pool:
pool, to keep performance reasonable - we can't avoid AutoReconnects pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway. completely anyway.
""" """
idle_time_seconds = sock_info.idle_time_seconds() idle_time_seconds = connection.idle_time_seconds()
# If socket is idle, open a new one. # If socket is idle, open a new one.
if ( if (
self.opts.max_idle_time_seconds is not None self.opts.max_idle_time_seconds is not None
and idle_time_seconds > self.opts.max_idle_time_seconds and idle_time_seconds > self.opts.max_idle_time_seconds
): ):
sock_info.close_socket(ConnectionClosedReason.IDLE) connection.close_socket(ConnectionClosedReason.IDLE)
return True return True
if self._check_interval_seconds is not None and ( if self._check_interval_seconds is not None and (
0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds
): ):
if sock_info.socket_closed(): if connection.socket_closed():
sock_info.close_socket(ConnectionClosedReason.ERROR) connection.close_socket(ConnectionClosedReason.ERROR)
return True return True
if self.stale_generation(sock_info.generation, sock_info.service_id): if self.stale_generation(connection.generation, connection.service_id):
sock_info.close_socket(ConnectionClosedReason.STALE) connection.close_socket(ConnectionClosedReason.STALE)
return True return True
return False return False
@ -1772,5 +1843,5 @@ class Pool:
# Avoid ResourceWarnings in Python 3 # Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is # Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__. # not safe to acquire a lock in __del__.
for sock_info in self.sockets: for connection in self.sockets:
sock_info.close_socket(None) connection.close_socket(None)

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
from datetime import timedelta from datetime import timedelta
from pymongo.message import _OpMsg, _OpReply from pymongo.message import _OpMsg, _OpReply
from pymongo.pool import SocketInfo from pymongo.pool import Connection
from pymongo.typings import _Address from pymongo.typings import _Address
@ -91,7 +91,7 @@ class PinnedResponse(Response):
self, self,
data: Union[_OpMsg, _OpReply], data: Union[_OpMsg, _OpReply],
address: _Address, address: _Address,
socket_info: SocketInfo, socket_info: Connection,
request_id: int, request_id: int,
duration: Optional[timedelta], duration: Optional[timedelta],
from_command: bool, from_command: bool,
@ -103,7 +103,7 @@ class PinnedResponse(Response):
:Parameters: :Parameters:
- `data`: A network response message. - `data`: A network response message.
- `address`: (host, port) of the source server. - `address`: (host, port) of the source server.
- `socket_info`: The SocketInfo used for the initial query. - `socket_info`: The Connection used for the initial query.
- `request_id`: The request id of this operation. - `request_id`: The request id of this operation.
- `duration`: The duration of the operation. - `duration`: The duration of the operation.
- `from_command`: If the response is the result of a db command. - `from_command`: If the response is the result of a db command.
@ -116,8 +116,8 @@ class PinnedResponse(Response):
self._more_to_come = more_to_come self._more_to_come = more_to_come
@property @property
def socket_info(self) -> SocketInfo: def socket_info(self) -> Connection:
"""The SocketInfo used for the initial query. """The Connection used for the initial query.
The server will send batches on this socket, without waiting for The server will send batches on this socket, without waiting for
getMores from the client, until the result set is exhausted or there getMores from the client, until the result set is exhausted or there

View File

@ -42,7 +42,7 @@ if TYPE_CHECKING:
from pymongo.mongo_client import _MongoClientErrorHandler from pymongo.mongo_client import _MongoClientErrorHandler
from pymongo.monitor import Monitor from pymongo.monitor import Monitor
from pymongo.monitoring import _EventListeners from pymongo.monitoring import _EventListeners
from pymongo.pool import Pool, SocketInfo from pymongo.pool import Connection, Pool
from pymongo.server_description import ServerDescription from pymongo.server_description import ServerDescription
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} _CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
@ -105,7 +105,7 @@ class Server:
@_handle_reauth @_handle_reauth
def run_operation( def run_operation(
self, self,
sock_info: SocketInfo, connection: Connection,
operation: Union[_Query, _GetMore], operation: Union[_Query, _GetMore],
read_preference: bool, read_preference: bool,
listeners: _EventListeners, listeners: _EventListeners,
@ -118,7 +118,7 @@ class Server:
Can raise ConnectionFailure, OperationFailure, etc. Can raise ConnectionFailure, OperationFailure, etc.
:Parameters: :Parameters:
- `sock_info`: A SocketInfo instance. - `connection`: A Connection instance.
- `operation`: A _Query or _GetMore object. - `operation`: A _Query or _GetMore object.
- `read_preference`: The read preference to use. - `read_preference`: The read preference to use.
- `listeners`: Instance of _EventListeners or None. - `listeners`: Instance of _EventListeners or None.
@ -129,27 +129,27 @@ class Server:
if publish: if publish:
start = datetime.now() start = datetime.now()
use_cmd = operation.use_command(sock_info) use_cmd = operation.use_command(connection)
more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come
if more_to_come: if more_to_come:
request_id = 0 request_id = 0
else: else:
message = operation.get_message(read_preference, sock_info, use_cmd) message = operation.get_message(read_preference, connection, use_cmd)
request_id, data, max_doc_size = self._split_message(message) request_id, data, max_doc_size = self._split_message(message)
if publish: if publish:
cmd, dbn = operation.as_command(sock_info) cmd, dbn = operation.as_command(connection)
listeners.publish_command_start( listeners.publish_command_start(
cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id cmd, dbn, request_id, connection.address, service_id=connection.service_id
) )
start = datetime.now() start = datetime.now()
try: try:
if more_to_come: if more_to_come:
reply = sock_info.receive_message(None) reply = connection.receive_message(None)
else: else:
sock_info.send_message(data, max_doc_size) connection.send_message(data, max_doc_size)
reply = sock_info.receive_message(request_id) reply = connection.receive_message(request_id)
# Unpack and check for command errors. # Unpack and check for command errors.
if use_cmd: if use_cmd:
@ -168,7 +168,7 @@ class Server:
if use_cmd: if use_cmd:
first = docs[0] first = docs[0]
operation.client._process_response(first, operation.session) operation.client._process_response(first, operation.session)
_check_command_response(first, sock_info.max_wire_version) _check_command_response(first, connection.max_wire_version)
except Exception as exc: except Exception as exc:
if publish: if publish:
duration = datetime.now() - start duration = datetime.now() - start
@ -181,8 +181,8 @@ class Server:
failure, failure,
operation.name, operation.name,
request_id, request_id,
sock_info.address, connection.address,
service_id=sock_info.service_id, service_id=connection.service_id,
) )
raise raise
@ -205,8 +205,8 @@ class Server:
res, res,
operation.name, operation.name,
request_id, request_id,
sock_info.address, connection.address,
service_id=sock_info.service_id, service_id=connection.service_id,
) )
# Decrypt response. # Decrypt response.
@ -219,7 +219,7 @@ class Server:
response: Response response: Response
if client._should_pin_cursor(operation.session) or operation.exhaust: if client._should_pin_cursor(operation.session) or operation.exhaust:
sock_info.pin_cursor() connection.pin_cursor()
if isinstance(reply, _OpMsg): if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the # In OP_MSG, the server keeps sending only if the
# more_to_come flag is set. # more_to_come flag is set.
@ -232,7 +232,7 @@ class Server:
response = PinnedResponse( response = PinnedResponse(
data=reply, data=reply,
address=self._description.address, address=self._description.address,
socket_info=sock_info, socket_info=connection,
duration=duration, duration=duration,
request_id=request_id, request_id=request_id,
from_command=use_cmd, from_command=use_cmd,
@ -253,7 +253,7 @@ class Server:
def get_socket( def get_socket(
self, handler: Optional[_MongoClientErrorHandler] = None self, handler: Optional[_MongoClientErrorHandler] = None
) -> ContextManager[SocketInfo]: ) -> ContextManager[Connection]:
return self.pool.get_socket(handler) return self.pool.get_socket(handler)
@property @property

View File

@ -778,6 +778,7 @@ class Topology:
driver=options.driver, driver=options.driver,
pause_enabled=False, pause_enabled=False,
server_api=options.server_api, server_api=options.server_api,
protocol=options.protocol,
) )
return self._settings.pool_class(address, monitor_pool_options, handshake=False) return self._settings.pool_class(address, monitor_pool_options, handshake=False)

View File

@ -41,6 +41,7 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"dnspython>=1.16.0,<3.0.0", "dnspython>=1.16.0,<3.0.0",
"grpcio>=1.56.0"
] ]
[project.optional-dependencies] [project.optional-dependencies]

7
test/grpc_example.py Normal file
View File

@ -0,0 +1,7 @@
from pymongo import MongoClient
client = MongoClient("mongodb://host9.local.10gen.cc:9901", grpc=True, loadBalanced=True)
dbs = client.list_databases()
for db in dbs:
print(db["name"], ": ", client[db["name"]].list_collection_names())

View File

@ -48,10 +48,10 @@ class MockPool(Pool):
client.mock_standalones + client.mock_members + client.mock_mongoses client.mock_standalones + client.mock_members + client.mock_mongoses
), ("bad host: %s" % host_and_port) ), ("bad host: %s" % host_and_port)
with Pool.get_socket(self, handler) as sock_info: with Pool.get_socket(self, handler) as connection:
sock_info.mock_host = self.mock_host connection.mock_host = self.mock_host
sock_info.mock_port = self.mock_port connection.mock_port = self.mock_port
yield sock_info yield connection
class DummyMonitor: class DummyMonitor:

View File

@ -96,7 +96,7 @@ from pymongo.errors import (
) )
from pymongo.mongo_client import MongoClient from pymongo.mongo_client import MongoClient
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
from pymongo.pool import _METADATA, PoolOptions, SocketInfo from pymongo.pool import _METADATA, Connection, PoolOptions
from pymongo.read_preferences import ReadPreference from pymongo.read_preferences import ReadPreference
from pymongo.server_description import ServerDescription from pymongo.server_description import ServerDescription
from pymongo.server_selectors import readable_server_selector, writable_server_selector from pymongo.server_selectors import readable_server_selector, writable_server_selector
@ -541,10 +541,10 @@ class TestClient(IntegrationTest):
# Assert reaper doesn't remove sockets when maxIdleTimeMS not set # Assert reaper doesn't remove sockets when maxIdleTimeMS not set
client = rs_or_single_client() client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector) server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info: with server._pool.get_socket() as connection:
pass pass
self.assertEqual(1, len(server._pool.sockets)) self.assertEqual(1, len(server._pool.sockets))
self.assertTrue(sock_info in server._pool.sockets) self.assertTrue(connection in server._pool.sockets)
client.close() client.close()
def test_max_idle_time_reaper_removes_stale_minPoolSize(self): def test_max_idle_time_reaper_removes_stale_minPoolSize(self):
@ -552,12 +552,12 @@ class TestClient(IntegrationTest):
# Assert reaper removes idle socket and replaces it with a new one # Assert reaper removes idle socket and replaces it with a new one
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
server = client._get_topology().select_server(readable_server_selector) server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info: with server._pool.get_socket() as connection:
pass pass
# When the reaper runs at the same time as the get_socket, two # When the reaper runs at the same time as the get_socket, two
# sockets could be created and checked into the pool. # sockets could be created and checked into the pool.
self.assertGreaterEqual(len(server._pool.sockets), 1) self.assertGreaterEqual(len(server._pool.sockets), 1)
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") wait_until(lambda: connection not in server._pool.sockets, "remove stale socket")
wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket") wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket")
client.close() client.close()
@ -566,12 +566,12 @@ class TestClient(IntegrationTest):
# Assert reaper respects maxPoolSize when adding new sockets. # Assert reaper respects maxPoolSize when adding new sockets.
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
server = client._get_topology().select_server(readable_server_selector) server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info: with server._pool.get_socket() as connection:
pass pass
# When the reaper runs at the same time as the get_socket, # When the reaper runs at the same time as the get_socket,
# maxPoolSize=1 should prevent two sockets from being created. # maxPoolSize=1 should prevent two sockets from being created.
self.assertEqual(1, len(server._pool.sockets)) self.assertEqual(1, len(server._pool.sockets))
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") wait_until(lambda: connection not in server._pool.sockets, "remove stale socket")
wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket") wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket")
client.close() client.close()
@ -605,39 +605,39 @@ class TestClient(IntegrationTest):
wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets")
# Assert that if a socket is closed, a new one takes its place # Assert that if a socket is closed, a new one takes its place
with server._pool.get_socket() as sock_info: with server._pool.get_socket() as connection:
sock_info.close_socket(None) connection.close_socket(None)
wait_until( wait_until(
lambda: 10 == len(server._pool.sockets), lambda: 10 == len(server._pool.sockets),
"a closed socket gets replaced from the pool", "a closed socket gets replaced from the pool",
) )
self.assertFalse(sock_info in server._pool.sockets) self.assertFalse(connection in server._pool.sockets)
def test_max_idle_time_checkout(self): def test_max_idle_time_checkout(self):
# Use high frequency to test _get_socket_no_auth. # Use high frequency to test _get_socket_no_auth.
with client_knobs(kill_cursor_frequency=99999999): with client_knobs(kill_cursor_frequency=99999999):
client = rs_or_single_client(maxIdleTimeMS=500) client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(readable_server_selector) server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info: with server._pool.get_socket() as connection:
pass pass
self.assertEqual(1, len(server._pool.sockets)) self.assertEqual(1, len(server._pool.sockets))
time.sleep(1) # Sleep so that the socket becomes stale. time.sleep(1) # Sleep so that the socket becomes stale.
with server._pool.get_socket() as new_sock_info: with server._pool.get_socket() as new_sock_info:
self.assertNotEqual(sock_info, new_sock_info) self.assertNotEqual(connection, new_sock_info)
self.assertEqual(1, len(server._pool.sockets)) self.assertEqual(1, len(server._pool.sockets))
self.assertFalse(sock_info in server._pool.sockets) self.assertFalse(connection in server._pool.sockets)
self.assertTrue(new_sock_info in server._pool.sockets) self.assertTrue(new_sock_info in server._pool.sockets)
# Test that sockets are reused if maxIdleTimeMS is not set. # Test that sockets are reused if maxIdleTimeMS is not set.
client = rs_or_single_client() client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector) server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info: with server._pool.get_socket() as connection:
pass pass
self.assertEqual(1, len(server._pool.sockets)) self.assertEqual(1, len(server._pool.sockets))
time.sleep(1) time.sleep(1)
with server._pool.get_socket() as new_sock_info: with server._pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info) self.assertEqual(connection, new_sock_info)
self.assertEqual(1, len(server._pool.sockets)) self.assertEqual(1, len(server._pool.sockets))
def test_constants(self): def test_constants(self):
@ -1130,8 +1130,8 @@ class TestClient(IntegrationTest):
def test_socketKeepAlive(self): def test_socketKeepAlive(self):
pool = get_pool(self.client) pool = get_pool(self.client)
with pool.get_socket() as sock_info: with pool.get_socket() as connection:
keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) keepalive = connection.connector.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
self.assertTrue(keepalive) self.assertTrue(keepalive)
@no_type_check @no_type_check
@ -1326,13 +1326,13 @@ class TestClient(IntegrationTest):
connected(client) connected(client)
# Cause a network error. # Cause a network error.
sock_info = one(pool.sockets) connection = one(pool.sockets)
sock_info.sock.close() connection.connector.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST) cursor = collection.find(cursor_type=CursorType.EXHAUST)
with self.assertRaises(ConnectionFailure): with self.assertRaises(ConnectionFailure):
next(cursor) next(cursor)
self.assertTrue(sock_info.closed) self.assertTrue(connection.closed)
# The semaphore was decremented despite the error. # The semaphore was decremented despite the error.
self.assertEqual(0, pool.requests) self.assertEqual(0, pool.requests)
@ -1350,7 +1350,7 @@ class TestClient(IntegrationTest):
socket_info = one(pool.sockets) socket_info = one(pool.sockets)
socket_info.sock.close() socket_info.sock.close()
# SocketInfo.authenticate logs, but gets a socket.error. Should be # Connection.authenticate logs, but gets a socket.error. Should be
# reraised as AutoReconnect. # reraised as AutoReconnect.
self.assertRaises(AutoReconnect, c.test.collection.find_one) self.assertRaises(AutoReconnect, c.test.collection.find_one)
@ -1847,7 +1847,7 @@ class TestExhaustCursor(IntegrationTest):
collection = client.pymongo_test.test collection = client.pymongo_test.test
pool = get_pool(client) pool = get_pool(client)
sock_info = one(pool.sockets) connection = one(pool.sockets)
# This will cause OperationFailure in all mongo versions since # This will cause OperationFailure in all mongo versions since
# the value for $orderby must be a document. # the value for $orderby must be a document.
@ -1856,10 +1856,10 @@ class TestExhaustCursor(IntegrationTest):
) )
self.assertRaises(OperationFailure, cursor.next) self.assertRaises(OperationFailure, cursor.next)
self.assertFalse(sock_info.closed) self.assertFalse(connection.closed)
# The socket was checked in and the semaphore was decremented. # The socket was checked in and the semaphore was decremented.
self.assertIn(sock_info, pool.sockets) self.assertIn(connection, pool.sockets)
self.assertEqual(0, pool.requests) self.assertEqual(0, pool.requests)
def test_exhaust_getmore_server_error(self): def test_exhaust_getmore_server_error(self):
@ -1874,7 +1874,7 @@ class TestExhaustCursor(IntegrationTest):
pool = get_pool(client) pool = get_pool(client)
pool._check_interval_seconds = None # Never check. pool._check_interval_seconds = None # Never check.
sock_info = one(pool.sockets) connection = one(pool.sockets)
cursor = collection.find(cursor_type=CursorType.EXHAUST) cursor = collection.find(cursor_type=CursorType.EXHAUST)
@ -1884,21 +1884,21 @@ class TestExhaustCursor(IntegrationTest):
# Cause a server error on getmore. # Cause a server error on getmore.
def receive_message(request_id): def receive_message(request_id):
# Discard the actual server response. # Discard the actual server response.
SocketInfo.receive_message(sock_info, request_id) Connection.receive_message(connection, request_id)
# responseFlags bit 1 is QueryFailure. # responseFlags bit 1 is QueryFailure.
msg = struct.pack("<iiiii", 1 << 1, 0, 0, 0, 0) msg = struct.pack("<iiiii", 1 << 1, 0, 0, 0, 0)
msg += encode({"$err": "mock err", "code": 0}) msg += encode({"$err": "mock err", "code": 0})
return message._OpReply.unpack(msg) return message._OpReply.unpack(msg)
sock_info.receive_message = receive_message connection.receive_message = receive_message
self.assertRaises(OperationFailure, list, cursor) self.assertRaises(OperationFailure, list, cursor)
# Unpatch the instance. # Unpatch the instance.
del sock_info.receive_message del connection.receive_message
# The socket is returned to the pool and it still works. # The socket is returned to the pool and it still works.
self.assertEqual(200, collection.count_documents({})) self.assertEqual(200, collection.count_documents({}))
self.assertIn(sock_info, pool.sockets) self.assertIn(connection, pool.sockets)
def test_exhaust_query_network_error(self): def test_exhaust_query_network_error(self):
# When doing an exhaust query, the socket stays checked out on success # When doing an exhaust query, the socket stays checked out on success
@ -1909,15 +1909,15 @@ class TestExhaustCursor(IntegrationTest):
pool._check_interval_seconds = None # Never check. pool._check_interval_seconds = None # Never check.
# Cause a network error. # Cause a network error.
sock_info = one(pool.sockets) connection = one(pool.sockets)
sock_info.sock.close() connection.connector.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST) cursor = collection.find(cursor_type=CursorType.EXHAUST)
self.assertRaises(ConnectionFailure, cursor.next) self.assertRaises(ConnectionFailure, cursor.next)
self.assertTrue(sock_info.closed) self.assertTrue(connection.closed)
# The socket was closed and the semaphore was decremented. # The socket was closed and the semaphore was decremented.
self.assertNotIn(sock_info, pool.sockets) self.assertNotIn(connection, pool.sockets)
self.assertEqual(0, pool.requests) self.assertEqual(0, pool.requests)
def test_exhaust_getmore_network_error(self): def test_exhaust_getmore_network_error(self):
@ -1936,15 +1936,15 @@ class TestExhaustCursor(IntegrationTest):
cursor.next() cursor.next()
# Cause a network error. # Cause a network error.
sock_info = cursor._Cursor__sock_mgr.sock connection = cursor._Cursor__sock_mgr.sock
sock_info.sock.close() connection.connector.close()
# A getmore fails. # A getmore fails.
self.assertRaises(ConnectionFailure, list, cursor) self.assertRaises(ConnectionFailure, list, cursor)
self.assertTrue(sock_info.closed) self.assertTrue(connection.closed)
# The socket was closed and the semaphore was decremented. # The socket was closed and the semaphore was decremented.
self.assertNotIn(sock_info, pool.sockets) self.assertNotIn(connection, pool.sockets)
self.assertEqual(0, pool.requests) self.assertEqual(0, pool.requests)
def test_gevent_task(self): def test_gevent_task(self):

View File

@ -123,19 +123,19 @@ class TestCMAP(IntegrationTest):
def check_out(self, op): def check_out(self, op):
"""Run the 'checkOut' operation.""" """Run the 'checkOut' operation."""
label = op["label"] label = op["label"]
with self.pool.get_socket() as sock_info: with self.pool.get_socket() as connection:
# Call 'pin_cursor' so we can hold the socket. # Call 'pin_cursor' so we can hold the socket.
sock_info.pin_cursor() connection.pin_cursor()
if label: if label:
self.labels[label] = sock_info self.labels[label] = connection
else: else:
self.addCleanup(sock_info.close_socket, None) self.addCleanup(connection.close_socket, None)
def check_in(self, op): def check_in(self, op):
"""Run the 'checkIn' operation.""" """Run the 'checkIn' operation."""
label = op["connection"] label = op["connection"]
sock_info = self.labels[label] connection = self.labels[label]
self.pool.return_socket(sock_info) self.pool.return_socket(connection)
def ready(self, op): def ready(self, op):
"""Run the 'ready' operation.""" """Run the 'ready' operation."""

View File

@ -94,7 +94,7 @@ def got_app_error(topology, app_error):
when = app_error["when"] when = app_error["when"]
max_wire_version = app_error["maxWireVersion"] max_wire_version = app_error["maxWireVersion"]
# XXX: We could get better test coverage by mocking the errors on the # XXX: We could get better test coverage by mocking the errors on the
# Pool/SocketInfo. # Pool/Connection.
try: try:
if error_type == "command": if error_type == "command":
_check_command_response(app_error["response"], max_wire_version) _check_command_response(app_error["response"], max_wire_version)
@ -279,7 +279,7 @@ class TestIgnoreStaleErrors(IntegrationTest):
def mock_command(*args, **kwargs): def mock_command(*args, **kwargs):
# Synchronize all threads to ensure they use the same generation. # Synchronize all threads to ensure they use the same generation.
barrier.wait() barrier.wait()
raise AutoReconnect("mock SocketInfo.command error") raise AutoReconnect("mock Connection.command error")
for sock in pool.sockets: for sock in pool.sockets:
sock.command = mock_command sock.command = mock_command

View File

@ -188,11 +188,11 @@ class TestPooling(_TestPoolingBase):
# Test Pool's _check_closed() method doesn't close a healthy socket. # Test Pool's _check_closed() method doesn't close a healthy socket.
cx_pool = self.create_pool(max_pool_size=10) cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check. cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket() as sock_info: with cx_pool.get_socket() as connection:
pass pass
with cx_pool.get_socket() as new_sock_info: with cx_pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info) self.assertEqual(connection, new_sock_info)
self.assertEqual(1, len(cx_pool.sockets)) self.assertEqual(1, len(cx_pool.sockets))
@ -200,12 +200,12 @@ class TestPooling(_TestPoolingBase):
# get_socket() returns socket after a non-network error. # get_socket() returns socket after a non-network error.
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
with cx_pool.get_socket() as sock_info: with cx_pool.get_socket() as connection:
1 / 0 1 / 0
# Socket was returned, not closed. # Socket was returned, not closed.
with cx_pool.get_socket() as new_sock_info: with cx_pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info) self.assertEqual(connection, new_sock_info)
self.assertEqual(1, len(cx_pool.sockets)) self.assertEqual(1, len(cx_pool.sockets))
@ -213,9 +213,9 @@ class TestPooling(_TestPoolingBase):
# Test that Pool removes explicitly closed socket. # Test that Pool removes explicitly closed socket.
cx_pool = self.create_pool() cx_pool = self.create_pool()
with cx_pool.get_socket() as sock_info: with cx_pool.get_socket() as connection:
# Use SocketInfo's API to close the socket. # Use Connection's API to close the socket.
sock_info.close_socket(None) connection.close_socket(None)
self.assertEqual(0, len(cx_pool.sockets)) self.assertEqual(0, len(cx_pool.sockets))
@ -225,15 +225,15 @@ class TestPooling(_TestPoolingBase):
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
cx_pool._check_interval_seconds = 0 # Always check. cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket() as sock_info: with cx_pool.get_socket() as connection:
# Simulate a closed socket without telling the SocketInfo it's # Simulate a closed socket without telling the Connection it's
# closed. # closed.
sock_info.sock.close() connection.connector.close()
self.assertTrue(sock_info.socket_closed()) self.assertTrue(connection.socket_closed())
with cx_pool.get_socket() as new_sock_info: with cx_pool.get_socket() as new_sock_info:
self.assertEqual(0, len(cx_pool.sockets)) self.assertEqual(0, len(cx_pool.sockets))
self.assertNotEqual(sock_info, new_sock_info) self.assertNotEqual(connection, new_sock_info)
self.assertEqual(1, len(cx_pool.sockets)) self.assertEqual(1, len(cx_pool.sockets))
@ -299,10 +299,10 @@ class TestPooling(_TestPoolingBase):
cx_pool._check_interval_seconds = 0 # Always check. cx_pool._check_interval_seconds = 0 # Always check.
self.addCleanup(cx_pool.close) self.addCleanup(cx_pool.close)
with cx_pool.get_socket() as sock_info: with cx_pool.get_socket() as connection:
# Simulate a closed socket without telling the SocketInfo it's # Simulate a closed socket without telling the Connection it's
# closed. # closed.
sock_info.sock.close() connection.connector.close()
# Swap pool's address with a bad one. # Swap pool's address with a bad one.
address, cx_pool.address = cx_pool.address, ("foo.com", 1234) address, cx_pool.address = cx_pool.address, ("foo.com", 1234)

View File

@ -286,16 +286,16 @@ class ReadPrefTester(MongoClient):
@contextlib.contextmanager @contextlib.contextmanager
def _socket_for_reads(self, read_preference, session): def _socket_for_reads(self, read_preference, session):
context = super()._socket_for_reads(read_preference, session) context = super()._socket_for_reads(read_preference, session)
with context as (sock_info, read_preference): with context as (connection, read_preference):
self.record_a_read(sock_info.address) self.record_a_read(connection.address)
yield sock_info, read_preference yield connection, read_preference
@contextlib.contextmanager @contextlib.contextmanager
def _socket_from_server(self, read_preference, server, session): def _socket_from_server(self, read_preference, server, session):
context = super()._socket_from_server(read_preference, server, session) context = super()._socket_from_server(read_preference, server, session)
with context as (sock_info, read_preference): with context as (connection, read_preference):
self.record_a_read(sock_info.address) self.record_a_read(connection.address)
yield sock_info, read_preference yield connection, read_preference
def record_a_read(self, address): def record_a_read(self, address):
server = self._get_topology().select_server_by_address(address, 0) server = self._get_topology().select_server_by_address(address, 0)