diff --git a/pymongo/aggregation.py b/pymongo/aggregation.py index cd86564fe..3d3f34ab1 100644 --- a/pymongo/aggregation.py +++ b/pymongo/aggregation.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from pymongo.collection import Collection from pymongo.command_cursor import CommandCursor from pymongo.database import Database - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_preferences import _ServerMode from pymongo.server import Server from pymongo.typings import _Pipeline @@ -52,7 +52,7 @@ class _AggregationCommand: explicit_session: bool, let: Optional[Mapping[str, Any]] = None, user_fields: Optional[MutableMapping[str, Any]] = None, - result_processor: Optional[Callable[[Mapping[str, Any], SocketInfo], None]] = None, + result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None, comment: Any = None, ) -> None: if "explain" in options: @@ -134,7 +134,7 @@ class _AggregationCommand: self, session: ClientSession, server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: _ServerMode, ) -> CommandCursor: # Serialize command. @@ -146,7 +146,7 @@ class _AggregationCommand: # - server version is >= 4.2 or # - server version is >= 3.2 and pipeline doesn't use $out if ("readConcern" not in cmd) and ( - not self._performs_write or (sock_info.max_wire_version >= 8) + not self._performs_write or (connection.max_wire_version >= 8) ): read_concern = self._target.read_concern else: @@ -161,7 +161,7 @@ class _AggregationCommand: write_concern = None # Run command. - result = sock_info.command( + result = connection.command( self._database.name, cmd, read_preference, @@ -176,7 +176,7 @@ class _AggregationCommand: ) if self._result_processor: - self._result_processor(result, sock_info) + self._result_processor(result, connection) # Extract cursor from result or mock/fake one if necessary. if "cursor" in result: @@ -193,14 +193,14 @@ class _AggregationCommand: cmd_cursor = self._cursor_class( self._cursor_collection(cursor), cursor, - sock_info.address, + connection.address, batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, session=session, explicit_session=self._explicit_session, comment=self._options.get("comment"), ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(connection) return cmd_cursor diff --git a/pymongo/auth.py b/pymongo/auth.py index b41e88542..c3617abc3 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -35,7 +35,7 @@ from pymongo.saslprep import saslprep if TYPE_CHECKING: from pymongo.hello import Hello - from pymongo.pool import SocketInfo + from pymongo.pool import Connection HAVE_KERBEROS = True _USE_PRINCIPAL = False @@ -221,7 +221,7 @@ def _authenticate_scram_start( def _authenticate_scram( - credentials: MongoCredential, sock_info: SocketInfo, mechanism: str + credentials: MongoCredential, connection: Connection, mechanism: str ) -> None: """Authenticate using SCRAM.""" username = credentials.username @@ -239,13 +239,13 @@ def _authenticate_scram( # Make local _hmac = hmac.HMAC - ctx = sock_info.auth_ctx + ctx = connection.auth_ctx if ctx and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate else: nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) - res = sock_info.command(source, cmd) + res = connection.command(source, cmd) server_first = res["payload"] parsed = _parse_scram_response(server_first) @@ -285,7 +285,7 @@ def _authenticate_scram( ("payload", Binary(client_final)), ] ) - res = sock_info.command(source, cmd) + res = connection.command(source, cmd) parsed = _parse_scram_response(res["payload"]) if not hmac.compare_digest(parsed[b"v"], server_sig): @@ -301,7 +301,7 @@ def _authenticate_scram( ("payload", Binary(b"")), ] ) - res = sock_info.command(source, cmd) + res = connection.command(source, cmd) if not res["done"]: raise OperationFailure("SASL conversation failed to complete.") @@ -345,7 +345,7 @@ def _canonicalize_hostname(hostname: str) -> str: return name[0].lower() -def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_gssapi(credentials: MongoCredential, connection: Connection) -> None: """Authenticate using GSSAPI.""" if not HAVE_KERBEROS: raise ConfigurationError( @@ -358,7 +358,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> props = credentials.mechanism_properties # Starting here and continuing through the while loop below - establish # the security context. See RFC 4752, Section 3.1, first paragraph. - host = sock_info.address[0] + host = connection.address[0] if props.canonicalize_host_name: host = _canonicalize_hostname(host) service = props.service_name + "@" + host @@ -413,7 +413,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> ("autoAuthorize", 1), ] ) - response = sock_info.command("$external", cmd) + response = connection.command("$external", cmd) # Limit how many times we loop to catch protocol / library issues for _ in range(10): @@ -430,7 +430,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> ("payload", payload), ] ) - response = sock_info.command("$external", cmd) + response = connection.command("$external", cmd) if result == kerberos.AUTH_GSS_COMPLETE: break @@ -453,7 +453,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> ("payload", payload), ] ) - sock_info.command("$external", cmd) + connection.command("$external", cmd) finally: kerberos.authGSSClientClean(ctx) @@ -462,7 +462,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> raise OperationFailure(str(exc)) -def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_plain(credentials: MongoCredential, connection: Connection) -> None: """Authenticate using SASL PLAIN (RFC 4616)""" source = credentials.source username = credentials.username @@ -476,52 +476,52 @@ def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> ("autoAuthorize", 1), ] ) - sock_info.command(source, cmd) + connection.command(source, cmd) -def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_x509(credentials: MongoCredential, connection: Connection) -> None: """Authenticate using MONGODB-X509.""" - ctx = sock_info.auth_ctx + ctx = connection.auth_ctx if ctx and ctx.speculate_succeeded(): # MONGODB-X509 is done after the speculative auth step. return - cmd = _X509Context(credentials, sock_info.address).speculate_command() - sock_info.command("$external", cmd) + cmd = _X509Context(credentials, connection.address).speculate_command() + connection.command("$external", cmd) -def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_mongo_cr(credentials: MongoCredential, connection: Connection) -> None: """Authenticate using MONGODB-CR.""" source = credentials.source username = credentials.username password = credentials.password # Get a nonce - response = sock_info.command(source, {"getnonce": 1}) + response = connection.command(source, {"getnonce": 1}) nonce = response["nonce"] key = _auth_key(nonce, username, password) # Actually authenticate query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)]) - sock_info.command(source, query) + connection.command(source, query) -def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None: - if sock_info.max_wire_version >= 7: - if sock_info.negotiated_mechs: - mechs = sock_info.negotiated_mechs +def _authenticate_default(credentials: MongoCredential, connection: Connection) -> None: + if connection.max_wire_version >= 7: + if connection.negotiated_mechs: + mechs = connection.negotiated_mechs else: source = credentials.source - cmd = sock_info.hello_cmd() + cmd = connection.hello_cmd() cmd["saslSupportedMechs"] = source + "." + credentials.username - mechs = sock_info.command(source, cmd, publish_events=False).get( + mechs = connection.command(source, cmd, publish_events=False).get( "saslSupportedMechs", [] ) if "SCRAM-SHA-256" in mechs: - return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256") + return _authenticate_scram(credentials, connection, "SCRAM-SHA-256") else: - return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") + return _authenticate_scram(credentials, connection, "SCRAM-SHA-1") else: - return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") + return _authenticate_scram(credentials, connection, "SCRAM-SHA-1") _AUTH_MAP: Mapping[str, Callable] = { @@ -606,12 +606,12 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { def authenticate( - credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False + credentials: MongoCredential, connection: Connection, reauthenticate: bool = False ) -> None: - """Authenticate sock_info.""" + """Authenticate connection.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] if mechanism == "MONGODB-OIDC": - _authenticate_oidc(credentials, sock_info, reauthenticate) + _authenticate_oidc(credentials, connection, reauthenticate) else: - auth_func(credentials, sock_info) + auth_func(credentials, connection) diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index edefd3c93..36b070bd3 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -49,7 +49,7 @@ from pymongo.errors import ConfigurationError, OperationFailure if TYPE_CHECKING: from bson.typings import _ReadableBuffer from pymongo.auth import MongoCredential - from pymongo.pool import SocketInfo + from pymongo.pool import Connection class _AwsSaslContext(AwsSaslContext): # type: ignore @@ -67,7 +67,7 @@ class _AwsSaslContext(AwsSaslContext): # type: ignore return bson.decode(data) -def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_aws(credentials: MongoCredential, connection: Connection) -> None: """Authenticate using MONGODB-AWS.""" if not _HAVE_MONGODB_AWS: raise ConfigurationError( @@ -75,7 +75,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No "install with: python -m pip install 'pymongo[aws]'" ) - if sock_info.max_wire_version < 9: + if connection.max_wire_version < 9: raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") try: @@ -90,7 +90,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No client_first = SON( [("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)] ) - server_first = sock_info.command("$external", client_first) + server_first = connection.command("$external", client_first) res = server_first # Limit how many times we loop to catch protocol / library issues for _ in range(10): @@ -102,7 +102,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No ("payload", client_payload), ] ) - res = sock_info.command("$external", cmd) + res = connection.command("$external", cmd) if res["done"]: # SASL complete. break diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 62648b2c0..f38b6fc85 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -29,7 +29,7 @@ from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE if TYPE_CHECKING: from pymongo.auth import MongoCredential - from pymongo.pool import SocketInfo + from pymongo.pool import Connection @dataclass @@ -243,24 +243,24 @@ class _OIDCAuthenticator: self.token_exp_utc = None def run_command( - self, sock_info: SocketInfo, cmd: Mapping[str, Any] + self, connection: Connection, cmd: Mapping[str, Any] ) -> Optional[Mapping[str, Any]]: try: - return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] + return connection.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] except OperationFailure as exc: self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if "jwt" in bson.decode(cmd["payload"]): if self.idp_info_gen_id > self.reauth_gen_id: raise - return self.authenticate(sock_info, reauthenticate=True) + return self.authenticate(connection, reauthenticate=True) raise def authenticate( - self, sock_info: SocketInfo, reauthenticate: bool = False + self, connection: Connection, reauthenticate: bool = False ) -> Optional[Mapping[str, Any]]: if reauthenticate: - prev_id = getattr(sock_info, "oidc_token_gen_id", None) + prev_id = getattr(connection, "oidc_token_gen_id", None) # Check if we've already changed tokens. if prev_id == self.token_gen_id: self.reauth_gen_id = self.idp_info_gen_id @@ -268,7 +268,7 @@ class _OIDCAuthenticator: if not self.properties.refresh_token_callback: self.clear() - ctx = sock_info.auth_ctx + ctx = connection.auth_ctx cmd = None if ctx and ctx.speculate_succeeded(): @@ -276,10 +276,10 @@ class _OIDCAuthenticator: else: cmd = self.auth_start_cmd() assert cmd is not None - resp = self.run_command(sock_info, cmd) + resp = self.run_command(connection, cmd) if resp["done"]: - sock_info.oidc_token_gen_id = self.token_gen_id + connection.oidc_token_gen_id = self.token_gen_id return None server_resp: Dict = bson.decode(resp["payload"]) @@ -289,7 +289,7 @@ class _OIDCAuthenticator: conversation_id = resp["conversationId"] token = self.get_current_token() - sock_info.oidc_token_gen_id = self.token_gen_id + connection.oidc_token_gen_id = self.token_gen_id bin_payload = Binary(bson.encode({"jwt": token})) cmd = SON( [ @@ -298,7 +298,7 @@ class _OIDCAuthenticator: ("payload", bin_payload), ] ) - resp = self.run_command(sock_info, cmd) + resp = self.run_command(connection, cmd) if not resp["done"]: self.clear() raise OperationFailure("SASL conversation failed to complete.") @@ -306,8 +306,8 @@ class _OIDCAuthenticator: def _authenticate_oidc( - credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool + credentials: MongoCredential, connection: Connection, reauthenticate: bool ) -> Optional[Mapping[str, Any]]: """Authenticate using MONGODB-OIDC.""" - authenticator = _get_authenticator(credentials, sock_info.address) - return authenticator.authenticate(sock_info, reauthenticate=reauthenticate) + authenticator = _get_authenticator(credentials, connection.address) + return authenticator.authenticate(connection, reauthenticate=reauthenticate) diff --git a/pymongo/bulk.py b/pymongo/bulk.py index f7fdc805b..08ba35f15 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -65,7 +65,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from pymongo.collection import Collection - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _DELETE_ALL: int = 0 @@ -311,7 +311,7 @@ class _Bulk: generator: Iterator[Any], write_concern: WriteConcern, session: Optional[ClientSession], - sock_info: SocketInfo, + connection: Connection, op_id: int, retryable: bool, full_result: MutableMapping[str, Any], @@ -326,9 +326,9 @@ class _Bulk: self.next_run = None run = self.current_run - # sock_info.command validates the session, but we use - # sock_info.write_command. - sock_info.validate_session(client, session) + # connection.command validates the session, but we use + # connection.write_command. + connection.validate_session(client, session) last_run = False while run: @@ -341,7 +341,7 @@ class _Bulk: bwc = self.bulk_ctx_class( db_name, cmd_name, - sock_info, + connection, op_id, listeners, session, @@ -369,11 +369,11 @@ class _Bulk: if retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info) - sock_info.send_cluster_time(cmd, session, client) - sock_info.add_server_api(cmd) + session._apply_to(cmd, retryable, ReadPreference.PRIMARY, connection) + connection.send_cluster_time(cmd, session, client) + connection.add_server_api(cmd) # CSOT: apply timeout before encoding the command. - sock_info.apply_timeout(client, cmd) + connection.apply_timeout(client, cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. @@ -430,13 +430,13 @@ class _Bulk: op_id = _randint() def retryable_bulk( - session: Optional[ClientSession], sock_info: SocketInfo, retryable: bool + session: Optional[ClientSession], connection: Connection, retryable: bool ) -> None: self._execute_command( generator, write_concern, session, - sock_info, + connection, op_id, retryable, full_result, @@ -450,7 +450,7 @@ class _Bulk: _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[Any]) -> None: + def execute_op_msg_no_results(self, connection: Connection, generator: Iterator[Any]) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -466,7 +466,7 @@ class _Bulk: bwc = self.bulk_ctx_class( db_name, cmd_name, - sock_info, + connection, op_id, listeners, None, @@ -482,7 +482,7 @@ class _Bulk: ("writeConcern", {"w": 0}), ] ) - sock_info.add_server_api(cmd) + connection.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. to_send = bwc.execute_unack(cmd, ops, client) @@ -491,7 +491,7 @@ class _Bulk: def execute_command_no_results( self, - sock_info: SocketInfo, + connection: Connection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -516,7 +516,7 @@ class _Bulk: generator, initial_write_concern, None, - sock_info, + connection, op_id, False, full_result, @@ -527,7 +527,7 @@ class _Bulk: def execute_no_results( self, - sock_info: SocketInfo, + connection: Connection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -538,11 +538,11 @@ class _Bulk: raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") # Guard against unsupported unacknowledged writes. unack = write_concern and not write_concern.acknowledged - if unack and self.uses_hint_delete and sock_info.max_wire_version < 9: + if unack and self.uses_hint_delete and connection.max_wire_version < 9: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) - if unack and self.uses_hint_update and sock_info.max_wire_version < 8: + if unack and self.uses_hint_update and connection.max_wire_version < 8: raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) @@ -553,8 +553,8 @@ class _Bulk: ) if self.ordered: - return self.execute_command_no_results(sock_info, generator, write_concern) - return self.execute_op_msg_no_results(sock_info, generator) + return self.execute_command_no_results(connection, generator, write_concern) + return self.execute_op_msg_no_results(connection, generator) def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any: """Execute operations.""" @@ -573,8 +573,8 @@ class _Bulk: client = self.collection.database.client if not write_concern.acknowledged: - with client._socket_for_writes(session) as sock_info: - self.execute_no_results(sock_info, generator, write_concern) + with client._socket_for_writes(session) as connection: + self.execute_no_results(connection, generator, write_concern) return None else: return self.execute_command(generator, write_concern, session) diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 10bfd3623..d8ba0c103 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -78,7 +78,7 @@ if TYPE_CHECKING: from pymongo.collection import Collection from pymongo.database import Database from pymongo.mongo_client import MongoClient - from pymongo.pool import SocketInfo + from pymongo.pool import Connection def _resumable(exc: PyMongoError) -> bool: @@ -213,7 +213,7 @@ class ChangeStream(Generic[_DocumentType]): full_pipeline.extend(self._pipeline) return full_pipeline - def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> None: + def _process_result(self, result: Mapping[str, Any], connection: Connection) -> None: """Callback that caches the postBatchResumeToken or startAtOperationTime from a changeStream aggregate command response containing an empty batch of change documents. @@ -228,7 +228,7 @@ class ChangeStream(Generic[_DocumentType]): self._start_at_operation_time is None and self._uses_resume_after is False and self._uses_start_after is False - and sock_info.max_wire_version >= 7 + and connection.max_wire_version >= 7 ): self._start_at_operation_time = result.get("operationTime") # PYTHON-2181: informative error on missing operationTime. diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 91ef51a52..7d06dfa33 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -24,7 +24,7 @@ from pymongo.common import validate_boolean from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError from pymongo.monitoring import _EventListeners -from pymongo.pool import PoolOptions +from pymongo.pool import ConnectionProtocol, PoolOptions from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ( _ServerMode, @@ -162,6 +162,13 @@ def _parse_pool_options( ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) load_balanced = options.get("loadbalanced") max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) + + grpc_enabled = options.get("grpc", False) + if grpc_enabled: + protocol = ConnectionProtocol.GRPC + else: + protocol = ConnectionProtocol.TCP_SOCKET + return PoolOptions( max_pool_size, min_pool_size, @@ -179,6 +186,7 @@ def _parse_pool_options( server_api=server_api, load_balanced=load_balanced, credentials=credentials, + protocol=protocol, ) diff --git a/pymongo/client_session.py b/pymongo/client_session.py index a43982e43..c11777394 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -178,7 +178,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from types import TracebackType - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.server import Server @@ -412,17 +412,17 @@ class _Transaction: return self.state == _TxnState.STARTING @property - def pinned_conn(self) -> Optional[SocketInfo]: + def pinned_conn(self) -> Optional[Connection]: if self.active() and self.sock_mgr: return self.sock_mgr.sock return None - def pin(self, server: Server, sock_info: SocketInfo) -> None: + def pin(self, server: Server, connection: Connection) -> None: self.sharded = True self.pinned_address = server.description.address if server.description.server_type == SERVER_TYPE.LoadBalancer: - sock_info.pin_txn() - self.sock_mgr = _SocketManager(sock_info, False) + connection.pin_txn() + self.sock_mgr = _SocketManager(connection, False) def unpin(self) -> None: self.pinned_address = None @@ -839,12 +839,12 @@ class ClientSession: - `command_name`: Either "commitTransaction" or "abortTransaction". """ - def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]: - return self._finish_transaction(sock_info, command_name) + def func(session: ClientSession, connection: Connection, retryable: bool) -> Dict[str, Any]: + return self._finish_transaction(connection, command_name) return self._client._retry_internal(True, func, self, None) - def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]: + def _finish_transaction(self, connection: Connection, command_name: str) -> Dict[str, Any]: self._transaction.attempt += 1 opts = self._transaction.opts assert opts @@ -868,7 +868,7 @@ class ClientSession: cmd["recoveryToken"] = self._transaction.recovery_token return self._client.admin._command( - sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True + connection, cmd, session=self, write_concern=wc, parse_write_concern_error=True ) def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: @@ -954,13 +954,13 @@ class ClientSession: return None @property - def _pinned_connection(self) -> Optional[SocketInfo]: + def _pinned_connection(self) -> Optional[Connection]: """The connection this transaction was started on.""" return self._transaction.pinned_conn - def _pin(self, server: Server, sock_info: SocketInfo) -> None: + def _pin(self, server: Server, connection: Connection) -> None: """Pin this session to the given Server or to the given connection.""" - self._transaction.pin(server, sock_info) + self._transaction.pin(server, connection) def _unpin(self) -> None: """Unpin this session from any pinned Server.""" @@ -985,12 +985,12 @@ class ClientSession: command: MutableMapping[str, Any], is_retryable: bool, read_preference: ReadPreference, - sock_info: SocketInfo, + connection: Connection, ) -> None: self._check_ended() self._materialize() if self.options.snapshot: - self._update_read_concern(command, sock_info) + self._update_read_concern(command, connection) self._server_session.last_use = time.monotonic() command["lsid"] = self._server_session.session_id @@ -1016,7 +1016,7 @@ class ClientSession: rc = self._transaction.opts.read_concern.document if rc: command["readConcern"] = rc - self._update_read_concern(command, sock_info) + self._update_read_concern(command, connection) command["txnNumber"] = self._server_session.transaction_id command["autocommit"] = False @@ -1025,11 +1025,11 @@ class ClientSession: self._check_ended() self._server_session.inc_transaction_id() - def _update_read_concern(self, cmd: MutableMapping[str, Any], sock_info: SocketInfo) -> None: + def _update_read_concern(self, cmd: MutableMapping[str, Any], connection: Connection) -> None: if self.options.causal_consistency and self.operation_time is not None: cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time if self.options.snapshot: - if sock_info.max_wire_version < 13: + if connection.max_wire_version < 13: raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later") rc = cmd.setdefault("readConcern", {}) rc["level"] = "snapshot" diff --git a/pymongo/collection.py b/pymongo/collection.py index 6a1bcf8c0..c05ae0557 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -126,7 +126,7 @@ if TYPE_CHECKING: from pymongo.client_session import ClientSession from pymongo.collation import Collation from pymongo.database import Database - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.server import Server @@ -264,15 +264,15 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _socket_for_reads( self, session: ClientSession - ) -> ContextManager[Tuple[SocketInfo, Union[PrimaryPreferred, Primary]]]: + ) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]: return self.__database.client._socket_for_reads(self._read_preference_for(session), session) - def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[SocketInfo]: + def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]: return self.__database.client._socket_for_writes(session) def _command( self, - sock_info: SocketInfo, + connection: Connection, command: Mapping[str, Any], read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions] = None, @@ -288,7 +288,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """Internal command helper. :Parameters: - - `sock_info` - A SocketInfo instance. + - `connection` - A Connection instance. - `command` - The command itself, as a :class:`~bson.son.SON` instance. - `read_preference` (optional) - The read preference to use. - `codec_options` (optional) - An instance of @@ -313,7 +313,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): The result document. """ with self.__database.client._tmp_session(session) as s: - return sock_info.command( + return connection.command( self.__database.name, command, read_preference or self._read_preference_for(session), @@ -348,16 +348,16 @@ class Collection(common.BaseObject, Generic[_DocumentType]): if "size" in options: options["size"] = float(options["size"]) cmd.update(options) - with self._socket_for_writes(session) as sock_info: - if qev2_required and sock_info.max_wire_version < 21: + with self._socket_for_writes(session) as connection: + if qev2_required and connection.max_wire_version < 21: raise ConfigurationError( "Driver support of Queryable Encryption is incompatible with server. " "Upgrade server to use Queryable Encryption. " - f"Got maxWireVersion {sock_info.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)" + f"Got maxWireVersion {connection.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)" ) self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), @@ -597,12 +597,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]): command["comment"] = comment def _insert_command( - session: ClientSession, sock_info: SocketInfo, retryable_write: bool + session: ClientSession, connection: Connection, retryable_write: bool ) -> None: if bypass_doc_val: command["bypassDocumentValidation"] = True - result = sock_info.command( + result = connection.command( self.__database.name, command, write_concern=write_concern, @@ -765,7 +765,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _update( self, - sock_info: SocketInfo, + connection: Connection, criteria: Mapping[str, Any], document: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, @@ -801,7 +801,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): else: update_doc["arrayFilters"] = array_filters if hint is not None: - if not acknowledged and sock_info.max_wire_version < 8: + if not acknowledged and connection.max_wire_version < 8: raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) @@ -821,7 +821,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): # The command result has to be published for APM unmodified # so we make a shallow copy here before adding updatedExisting. - result = sock_info.command( + result = connection.command( self.__database.name, command, write_concern=write_concern, @@ -865,10 +865,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """Internal update / replace helper.""" def _update( - session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool + session: Optional[ClientSession], connection: Connection, retryable_write: bool ) -> Optional[Mapping[str, Any]]: return self._update( - sock_info, + connection, criteria, document, upsert=upsert, @@ -1255,7 +1255,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _delete( self, - sock_info: SocketInfo, + connection: Connection, criteria: Mapping[str, Any], multi: bool, write_concern: Optional[WriteConcern] = None, @@ -1280,7 +1280,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): else: delete_doc["collation"] = collation if hint is not None: - if not acknowledged and sock_info.max_wire_version < 9: + if not acknowledged and connection.max_wire_version < 9: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) @@ -1297,7 +1297,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): command["comment"] = comment # Delete command. - result = sock_info.command( + result = connection.command( self.__database.name, command, write_concern=write_concern, @@ -1325,10 +1325,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """Internal delete helper.""" def _delete( - session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool + session: Optional[ClientSession], connection: Connection, retryable_write: bool ) -> Mapping[str, Any]: return self._delete( - sock_info, + connection, criteria, multi, write_concern=write_concern, @@ -1738,7 +1738,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _count_cmd( self, session: ClientSession, - sock_info: SocketInfo, + connection: Connection, read_preference: Optional[_ServerMode], cmd: Mapping[str, Any], collation: Optional[Collation], @@ -1747,7 +1747,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): # XXX: "ns missing" checks can be removed when we drop support for # MongoDB 3.0, see SERVER-17051. res = self._command( - sock_info, + connection, cmd, read_preference=read_preference, allowable_errors=["ns missing"], @@ -1762,7 +1762,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _aggregate_one_result( self, - sock_info: SocketInfo, + connection: Connection, read_preference: Optional[_ServerMode], cmd: Mapping[str, Any], collation: Optional[_CollationIn], @@ -1770,7 +1770,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): ) -> Optional[Mapping[str, Any]]: """Internal helper to run an aggregate that returns a single result.""" result = self._command( - sock_info, + connection, cmd, read_preference, allowable_errors=[26], # Ignore NamespaceNotFound. @@ -1821,12 +1821,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: Optional[_ServerMode], ) -> int: cmd: SON[str, Any] = SON([("count", self.__name)]) cmd.update(kwargs) - return self._count_cmd(session, sock_info, read_preference, cmd, collation=None) + return self._count_cmd(session, connection, read_preference, cmd, collation=None) return self._retryable_non_cursor_read(_cmd, None) @@ -1910,10 +1910,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: Optional[_ServerMode], ) -> int: - result = self._aggregate_one_result(sock_info, read_preference, cmd, collation, session) + result = self._aggregate_one_result( + connection, read_preference, cmd, collation, session + ) if not result: return 0 return result["n"] @@ -1922,7 +1924,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _retryable_non_cursor_read( self, - func: Callable[[ClientSession, Server, SocketInfo, Optional[_ServerMode]], T], + func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T], session: Optional[ClientSession], ) -> T: """Non-cursor read helper to handle implicit session creation.""" @@ -1993,8 +1995,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]): command (like maxTimeMS) can be passed as keyword arguments. """ names = [] - with self._socket_for_writes(session) as sock_info: - supports_quorum = sock_info.max_wire_version >= 9 + with self._socket_for_writes(session) as connection: + supports_quorum = connection.max_wire_version >= 9 def gen_indexes() -> Iterator[Mapping[str, Any]]: for index in indexes: @@ -2015,7 +2017,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): ) self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, @@ -2236,9 +2238,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._socket_for_writes(session) as sock_info: + with self._socket_for_writes(session) as connection: self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], @@ -2285,7 +2287,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: _ServerMode, ) -> CommandCursor[_DocumentType]: cmd = SON([("listIndexes", self.__name), ("cursor", {})]) @@ -2294,7 +2296,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): try: cursor = self._command( - sock_info, cmd, read_preference, codec_options, session=session + connection, cmd, read_preference, codec_options, session=session )["cursor"] except OperationFailure as exc: # Ignore NamespaceNotFound errors to match the behavior @@ -2305,12 +2307,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd_cursor = CommandCursor( coll, cursor, - sock_info.address, + connection.address, session=session, explicit_session=explicit_session, comment=cmd.get("comment"), ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(connection) return cmd_cursor with self.__database.client._tmp_session(session, False) as s: @@ -2479,9 +2481,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))]) cmd.update(kwargs) - with self._socket_for_writes(session) as sock_info: + with self._socket_for_writes(session) as connection: resp = self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, @@ -2514,9 +2516,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._socket_for_writes(session) as sock_info: + with self._socket_for_writes(session) as connection: self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], @@ -2551,9 +2553,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._socket_for_writes(session) as sock_info: + with self._socket_for_writes(session) as connection: self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], @@ -2980,9 +2982,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): cmd["comment"] = comment write_concern = self._write_concern_for_cmd(cmd, session) - with self._socket_for_writes(session) as sock_info: + with self._socket_for_writes(session) as connection: with self.__database.client._tmp_session(session) as s: - return sock_info.command( + return connection.command( "admin", cmd, write_concern=write_concern, @@ -3049,11 +3051,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: Optional[_ServerMode], ) -> List: return self._command( - sock_info, + connection, cmd, read_preference=read_preference, read_concern=self.read_concern, @@ -3112,7 +3114,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): write_concern = self._write_concern_for_cmd(cmd, session) def _find_and_modify( - session: ClientSession, sock_info: SocketInfo, retryable_write: bool + session: ClientSession, connection: Connection, retryable_write: bool ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3122,17 +3124,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]): ) cmd["arrayFilters"] = list(array_filters) if hint is not None: - if sock_info.max_wire_version < 8: + if connection.max_wire_version < 8: raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on find and modify commands." ) - elif not acknowledged and sock_info.max_wire_version < 9: + elif not acknowledged and connection.max_wire_version < 9: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands." ) cmd["hint"] = hint out = self._command( - sock_info, + connection, cmd, read_preference=ReadPreference.PRIMARY, write_concern=write_concern, diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 4a3d0311c..4684f4232 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -38,7 +38,7 @@ from pymongo.typings import _Address, _DocumentType if TYPE_CHECKING: from pymongo.client_session import ClientSession from pymongo.collection import Collection - from pymongo.pool import SocketInfo + from pymongo.pool import Connection class CommandCursor(Generic[_DocumentType]): @@ -157,13 +157,13 @@ class CommandCursor(Generic[_DocumentType]): """ return self.__postbatchresumetoken - def _maybe_pin_connection(self, sock_info: SocketInfo) -> None: + def _maybe_pin_connection(self, connection: Connection) -> None: client = self.__collection.database.client if not client._should_pin_cursor(self.__session): return if not self.__sock_mgr: - sock_info.pin_cursor() - sock_mgr = _SocketManager(sock_info, False) + connection.pin_cursor() + sock_mgr = _SocketManager(connection, False) # Ensure the connection gets returned when the entire result is # returned in the first batch. if self.__id == 0: diff --git a/pymongo/common.py b/pymongo/common.py index a791a5e44..efe5c183c 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -750,6 +750,7 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = { "server_selector": validate_is_callable_or_none, "auto_encryption_opts": validate_auto_encryption_opts_or_none, "authoidcallowedhosts": validate_list, + "grpc": validate_boolean, } # Dictionary where keys are any URI option name, and values are the diff --git a/pymongo/cursor.py b/pymongo/cursor.py index b718d905e..83aed702c 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -64,7 +64,7 @@ if TYPE_CHECKING: from pymongo.client_session import ClientSession from pymongo.collection import Collection from pymongo.message import _OpMsg, _OpReply - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_preferences import _ServerMode @@ -142,8 +142,8 @@ class CursorType: class _SocketManager: """Used with exhaust cursors to ensure the socket is returned.""" - def __init__(self, sock: SocketInfo, more_to_come: bool): - self.sock: Optional[SocketInfo] = sock + def __init__(self, sock: Connection, more_to_come: bool): + self.sock: Optional[Connection] = sock self.more_to_come = more_to_come self.lock = _create_lock() diff --git a/pymongo/database.py b/pymongo/database.py index a55555cc4..83dd3b57a 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -48,7 +48,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.server import Server @@ -689,7 +689,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): @overload def _command( self, - sock_info: SocketInfo, + connection: Connection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -706,7 +706,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): @overload def _command( self, - sock_info: SocketInfo, + connection: Connection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -722,7 +722,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _command( self, - sock_info: SocketInfo, + connection: Connection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -742,7 +742,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): command.update(kwargs) with self.__client._tmp_session(session) as s: - return sock_info.command( + return connection.command( self.__name, command, read_preference, @@ -890,11 +890,11 @@ class Database(common.BaseObject, Generic[_DocumentType]): if read_preference is None: read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY with self.__client._socket_for_reads(read_preference, session) as ( - sock_info, + connection, read_preference, ): return self._command( - sock_info, + connection, command, value, check, @@ -974,11 +974,11 @@ class Database(common.BaseObject, Generic[_DocumentType]): tmp_session and tmp_session._txn_read_preference() ) or ReadPreference.PRIMARY with self.__client._socket_for_reads(read_preference, tmp_session) as ( - sock_info, + connection, read_preference, ): response = self._command( - sock_info, + connection, command, value, True, @@ -993,13 +993,13 @@ class Database(common.BaseObject, Generic[_DocumentType]): cmd_cursor = CommandCursor( coll, response["cursor"], - sock_info.address, + connection.address, max_await_time_ms=max_await_time_ms, session=tmp_session, explicit_session=session is not None, comment=comment, ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(connection) return cmd_cursor else: raise InvalidOperation("Command does not return a cursor.") @@ -1015,11 +1015,11 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: _ServerMode, ) -> Dict[str, Any]: return self._command( - sock_info, + connection, command, read_preference=read_preference, session=session, @@ -1029,7 +1029,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _list_collections( self, - sock_info: SocketInfo, + connection: Connection, session: Optional[ClientSession], read_preference: _ServerMode, **kwargs: Any, @@ -1040,17 +1040,17 @@ class Database(common.BaseObject, Generic[_DocumentType]): cmd.update(kwargs) with self.__client._tmp_session(session, close=False) as tmp_session: cursor = self._command( - sock_info, cmd, read_preference=read_preference, session=tmp_session + connection, cmd, read_preference=read_preference, session=tmp_session )["cursor"] cmd_cursor = CommandCursor( coll, cursor, - sock_info.address, + connection.address, session=tmp_session, explicit_session=session is not None, comment=cmd.get("comment"), ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(connection) return cmd_cursor def list_collections( @@ -1090,11 +1090,11 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _cmd( session: Optional[ClientSession], server: Server, - sock_info: SocketInfo, + connection: Connection, read_preference: _ServerMode, ) -> CommandCursor[_DocumentType]: return self._list_collections( - sock_info, session, read_preference=read_preference, **kwargs + connection, session, read_preference=read_preference, **kwargs ) return self.__client._retryable_read(_cmd, read_pref, session) @@ -1154,9 +1154,9 @@ class Database(common.BaseObject, Generic[_DocumentType]): if comment is not None: command["comment"] = comment - with self.__client._socket_for_writes(session) as sock_info: + with self.__client._socket_for_writes(session) as connection: return self._command( - sock_info, + connection, command, allowable_errors=["ns not found", 26], write_concern=self._write_concern_for(session), diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 9a4a2b04c..2c10778c1 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -307,7 +307,7 @@ F = TypeVar("F", bound=Callable[..., Any]) def _handle_reauth(func: F) -> F: def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) - from pymongo.pool import SocketInfo + from pymongo.pool import Connection try: return func(*args, **kwargs) @@ -315,19 +315,19 @@ def _handle_reauth(func: F) -> F: if no_reauth: raise if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - # Look for an argument that either is a SocketInfo + # Look for an argument that either is a Connection # or has a socket_info attribute, so we can trigger # a reauth. - sock_info = None + connection = None for arg in args: - if isinstance(arg, SocketInfo): - sock_info = arg + if isinstance(arg, Connection): + connection = arg break - if hasattr(arg, "sock_info"): - sock_info = arg.sock_info + if hasattr(arg, "connection"): + connection = arg.connection break - if sock_info: - sock_info.authenticate(reauthenticate=True) + if connection: + connection.authenticate(reauthenticate=True) else: raise return func(*args, **kwargs) diff --git a/pymongo/message.py b/pymongo/message.py index 735f8a8cc..3542ccdc5 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -227,14 +227,14 @@ def _gen_find_command( return cmd -def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, sock_info): +def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, connection): """Generate a getMore command document.""" cmd = SON([("getMore", cursor_id), ("collection", coll)]) if batch_size: cmd["batchSize"] = batch_size if max_await_time_ms is not None: cmd["maxTimeMS"] = max_await_time_ms - if comment is not None and sock_info.max_wire_version >= 9: + if comment is not None and connection.max_wire_version >= 9: cmd["comment"] = comment return cmd @@ -311,24 +311,24 @@ class _Query: def namespace(self): return f"{self.db}.{self.coll}" - def use_command(self, sock_info): + def use_command(self, connection): use_find_cmd = False if not self.exhaust: use_find_cmd = True - elif sock_info.max_wire_version >= 8: + elif connection.max_wire_version >= 8: # OP_MSG supports exhaust on MongoDB 4.2+ use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( "read concern level of %s is not valid " "with a max wire version of %d." - % (self.read_concern.level, sock_info.max_wire_version) + % (self.read_concern.level, connection.max_wire_version) ) - sock_info.validate_session(self.client, self.session) + connection.validate_session(self.client, self.session) return use_find_cmd - def as_command(self, sock_info, apply_timeout=False): + def as_command(self, connection, apply_timeout=False): """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. @@ -353,24 +353,24 @@ class _Query: self.name = "explain" cmd = SON([("explain", cmd)]) session = self.session - sock_info.add_server_api(cmd) + connection.add_server_api(cmd) if session: - session._apply_to(cmd, False, self.read_preference, sock_info) + session._apply_to(cmd, False, self.read_preference, connection) # Explain does not support readConcern. if not explain and not session.in_transaction: - session._update_read_concern(cmd, sock_info) - sock_info.send_cluster_time(cmd, session, self.client) + session._update_read_concern(cmd, connection) + connection.send_cluster_time(cmd, session, self.client) # Support auto encryption client = self.client if client._encrypter and not client._encrypter._bypass_auto_encryption: cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) # Support CSOT if apply_timeout: - sock_info.apply_timeout(client, cmd) + connection.apply_timeout(client, cmd) self._as_command = cmd, self.db return self._as_command - def get_message(self, read_preference, sock_info, use_cmd=False): + def get_message(self, read_preference, connection, use_cmd=False): """Get a query message, possibly setting the secondaryOk bit.""" # Use the read_preference decided by _socket_from_server. self.read_preference = read_preference @@ -384,14 +384,14 @@ class _Query: spec = self.spec if use_cmd: - spec = self.as_command(sock_info, apply_timeout=True)[0] + spec = self.as_command(connection, apply_timeout=True)[0] request_id, msg, size, _ = _op_msg( 0, spec, self.db, read_preference, self.codec_options, - ctx=sock_info.compression_context, + ctx=connection.compression_context, ) return request_id, msg, size @@ -405,7 +405,7 @@ class _Query: else: ntoreturn = self.limit - if sock_info.is_mongos: + if connection.is_mongos: spec = _maybe_add_read_preference(spec, read_preference) return _query( @@ -416,7 +416,7 @@ class _Query: spec, None if use_cmd else self.fields, self.codec_options, - ctx=sock_info.compression_context, + ctx=connection.compression_context, ) @@ -476,18 +476,18 @@ class _GetMore: def namespace(self): return f"{self.db}.{self.coll}" - def use_command(self, sock_info): + def use_command(self, connection): use_cmd = False if not self.exhaust: use_cmd = True - elif sock_info.max_wire_version >= 8: + elif connection.max_wire_version >= 8: # OP_MSG supports exhaust on MongoDB 4.2+ use_cmd = True - sock_info.validate_session(self.client, self.session) + connection.validate_session(self.client, self.session) return use_cmd - def as_command(self, sock_info, apply_timeout=False): + def as_command(self, connection, apply_timeout=False): """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: @@ -499,35 +499,35 @@ class _GetMore: self.ntoreturn, self.max_await_time_ms, self.comment, - sock_info, + connection, ) if self.session: - self.session._apply_to(cmd, False, self.read_preference, sock_info) - sock_info.add_server_api(cmd) - sock_info.send_cluster_time(cmd, self.session, self.client) + self.session._apply_to(cmd, False, self.read_preference, connection) + connection.add_server_api(cmd) + connection.send_cluster_time(cmd, self.session, self.client) # Support auto encryption client = self.client if client._encrypter and not client._encrypter._bypass_auto_encryption: cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) # Support CSOT if apply_timeout: - sock_info.apply_timeout(client, cmd=None) + connection.apply_timeout(client, cmd=None) self._as_command = cmd, self.db return self._as_command - def get_message(self, dummy0, sock_info, use_cmd=False): + def get_message(self, dummy0, connection, use_cmd=False): """Get a getmore message.""" ns = self.namespace() - ctx = sock_info.compression_context + ctx = connection.compression_context if use_cmd: - spec = self.as_command(sock_info, apply_timeout=True)[0] + spec = self.as_command(connection, apply_timeout=True)[0] if self.sock_mgr: flags = _OpMsg.EXHAUST_ALLOWED else: flags = 0 request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, self.codec_options, ctx=sock_info.compression_context + flags, spec, self.db, None, self.codec_options, ctx=connection.compression_context ) return request_id, msg, size @@ -535,10 +535,10 @@ class _GetMore: class _RawBatchQuery(_Query): - def use_command(self, sock_info): + def use_command(self, connection): # Compatibility checks. - super().use_command(sock_info) - if sock_info.max_wire_version >= 8: + super().use_command(connection) + if connection.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True elif not self.exhaust: @@ -547,10 +547,10 @@ class _RawBatchQuery(_Query): class _RawBatchGetMore(_GetMore): - def use_command(self, sock_info): + def use_command(self, connection): # Compatibility checks. - super().use_command(sock_info) - if sock_info.max_wire_version >= 8: + super().use_command(connection) + if connection.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True elif not self.exhaust: @@ -794,11 +794,11 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None): class _BulkWriteContext: - """A wrapper around SocketInfo for use with write splitting functions.""" + """A wrapper around Connection for use with write splitting functions.""" __slots__ = ( "db_name", - "sock_info", + "connection", "op_id", "name", "field", @@ -812,10 +812,10 @@ class _BulkWriteContext: ) def __init__( - self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec + self, database_name, cmd_name, connection, operation_id, listeners, session, op_type, codec ): self.db_name = database_name - self.sock_info = sock_info + self.connection = connection self.op_id = operation_id self.listeners = listeners self.publish = listeners.enabled_for_commands @@ -823,7 +823,7 @@ class _BulkWriteContext: self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() if self.publish else None self.session = session - self.compress = True if sock_info.compression_context else False + self.compress = True if connection.compression_context else False self.op_type = op_type self.codec = codec @@ -855,20 +855,20 @@ class _BulkWriteContext: @property def max_bson_size(self): """A proxy for SockInfo.max_bson_size.""" - return self.sock_info.max_bson_size + return self.connection.max_bson_size @property def max_message_size(self): """A proxy for SockInfo.max_message_size.""" if self.compress: # Subtract 16 bytes for the message header. - return self.sock_info.max_message_size - 16 - return self.sock_info.max_message_size + return self.connection.max_message_size - 16 + return self.connection.max_message_size @property def max_write_batch_size(self): """A proxy for SockInfo.max_write_batch_size.""" - return self.sock_info.max_write_batch_size + return self.connection.max_write_batch_size @property def max_split_size(self): @@ -876,14 +876,14 @@ class _BulkWriteContext: return self.max_bson_size def unack_write(self, cmd, request_id, msg, max_doc_size, docs): - """A proxy for SocketInfo.unack_write that handles event publishing.""" + """A proxy for Connection.unack_write that handles event publishing.""" if self.publish: assert self.start_time is not None duration = datetime.datetime.now() - self.start_time cmd = self._start(cmd, request_id, docs) start = datetime.datetime.now() try: - result = self.sock_info.unack_write(msg, max_doc_size) + result = self.connection.unack_write(msg, max_doc_size) if self.publish: duration = (datetime.datetime.now() - start) + duration if result is not None: @@ -910,14 +910,14 @@ class _BulkWriteContext: @_handle_reauth def write_command(self, cmd, request_id, msg, docs): - """A proxy for SocketInfo.write_command that handles event publishing.""" + """A proxy for Connection.write_command that handles event publishing.""" if self.publish: assert self.start_time is not None duration = datetime.datetime.now() - self.start_time self._start(cmd, request_id, docs) start = datetime.datetime.now() try: - reply = self.sock_info.write_command(request_id, msg, self.codec) + reply = self.connection.write_command(request_id, msg, self.codec) if self.publish: duration = (datetime.datetime.now() - start) + duration self._succeed(request_id, reply, duration) @@ -941,9 +941,9 @@ class _BulkWriteContext: cmd, self.db_name, request_id, - self.sock_info.address, + self.connection.address, self.op_id, - self.sock_info.service_id, + self.connection.service_id, ) return cmd @@ -954,9 +954,9 @@ class _BulkWriteContext: reply, self.name, request_id, - self.sock_info.address, + self.connection.address, self.op_id, - self.sock_info.service_id, + self.connection.service_id, ) def _fail(self, request_id, failure, duration): @@ -966,9 +966,9 @@ class _BulkWriteContext: failure, self.name, request_id, - self.sock_info.address, + self.connection.address, self.op_id, - self.sock_info.service_id, + self.connection.service_id, ) @@ -997,14 +997,14 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): def execute(self, cmd, docs, client): batched_cmd, to_send = self._batch_command(cmd, docs) - result = self.sock_info.command( + result = self.connection.command( self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client ) return result, to_send def execute_unack(self, cmd, docs, client): batched_cmd, to_send = self._batch_command(cmd, docs) - self.sock_info.command( + self.connection.command( self.db_name, batched_cmd, write_concern=WriteConcern(w=0), @@ -1124,7 +1124,7 @@ def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx): """ data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) - request_id, msg = _compress(2013, data, ctx.sock_info.compression_context) + request_id, msg = _compress(2013, data, ctx.connection.compression_context) return request_id, msg, to_send @@ -1162,7 +1162,7 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx): ack = bool(command["writeConcern"].get("w", 1)) else: ack = True - if ctx.sock_info.compression_context: + if ctx.connection.compression_context: return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) return _batched_op_msg(operation, command, docs, ack, opts, ctx) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index c8a265622..77152bbef 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1160,18 +1160,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): def _end_sessions(self, session_ids): """Send endSessions command(s) with the given session ids.""" try: - # Use SocketInfo.command directly to avoid implicitly creating + # Use Connection.command directly to avoid implicitly creating # another session. with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as ( - sock_info, + connection, read_pref, ): - if not sock_info.supports_sessions: + if not connection.supports_sessions: return for i in range(0, len(session_ids), common._MAX_END_SESSIONS): spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])]) - sock_info.command("admin", spec, read_preference=read_pref, client=self) + connection.command("admin", spec, read_preference=read_pref, client=self) except PyMongoError: # Drivers MUST ignore any errors returned by the endSessions # command. @@ -1224,23 +1224,23 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): err_handler.contribute_socket(session._pinned_connection) yield session._pinned_connection return - with server.get_socket(handler=err_handler) as sock_info: + with server.get_socket(handler=err_handler) as connection: # Pin this session to the selected server or connection. if in_txn and server.description.server_type in ( SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer, ): - session._pin(server, sock_info) - err_handler.contribute_socket(sock_info) + session._pin(server, connection) + err_handler.contribute_socket(connection) if ( self._encrypter and not self._encrypter._bypass_auto_encryption - and sock_info.max_wire_version < 8 + and connection.max_wire_version < 8 ): raise ConfigurationError( "Auto-encryption requires a minimum MongoDB version of 4.2" ) - yield sock_info + yield connection def _select_server(self, server_selector, session, address=None): """Select a server to run an operation on this client. @@ -1281,7 +1281,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): def _socket_from_server(self, read_preference, server, session): assert read_preference is not None, "read_preference must not be None" # Get a socket for a server matching the read preference, and yield - # sock_info with the effective read preference. The Server Selection + # connection with the effective read preference. The Server Selection # Spec says not to send any $readPreference to standalones and to # always send primaryPreferred when directly connected to a repl set # member. @@ -1289,16 +1289,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._get_socket(server, session) as sock_info: + with self._get_socket(server, session) as connection: if single: - if sock_info.is_repl and not (session and session.in_transaction): + if connection.is_repl and not (session and session.in_transaction): # Use primary preferred to ensure any repl set member # can handle the request. read_preference = ReadPreference.PRIMARY_PREFERRED - elif sock_info.is_standalone: + elif connection.is_standalone: # Don't send read preference to standalones. read_preference = ReadPreference.PRIMARY - yield sock_info, read_preference + yield connection, read_preference def _socket_for_reads(self, read_preference, session): assert read_preference is not None, "read_preference must not be None" @@ -1331,10 +1331,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res ) - def _cmd(session, server, sock_info, read_preference): + def _cmd(session, server, connection, read_preference): operation.reset() # Reset op in case of retry. return server.run_operation( - sock_info, operation, read_preference, self._event_listeners, unpack_res + connection, operation, read_preference, self._event_listeners, unpack_res ) return self._retryable_read( @@ -1388,8 +1388,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): supports_session = ( session is not None and server.description.retryable_writes_supported ) - with self._get_socket(server, session) as sock_info: - max_wire_version = sock_info.max_wire_version + with self._get_socket(server, session) as connection: + max_wire_version = connection.max_wire_version if retryable and not supports_session: if is_retrying(): # A retry is not possible because this server does @@ -1397,7 +1397,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): assert last_error is not None raise last_error retryable = False - return func(session, sock_info, retryable) + return func(session, connection, retryable) except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted @@ -1455,13 +1455,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): raise last_error try: server = self._select_server(read_pref, session, address=address) - with self._socket_from_server(read_pref, server, session) as (sock_info, read_pref): + with self._socket_from_server(read_pref, server, session) as ( + connection, + read_pref, + ): if retrying and not retryable: # A retry is not possible because this server does # not support retryable reads, raise the last error. assert last_error is not None raise last_error - return func(session, server, sock_info, read_pref) + return func(session, server, connection, read_pref) except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted @@ -1633,14 +1636,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # Application called close_cursor() with no address. server = topology.select_server(writable_server_selector) - with self._get_socket(server, session) as sock_info: - self._kill_cursor_impl(cursor_ids, address, session, sock_info) + with self._get_socket(server, session) as connection: + self._kill_cursor_impl(cursor_ids, address, session, connection) - def _kill_cursor_impl(self, cursor_ids, address, session, sock_info): + def _kill_cursor_impl(self, cursor_ids, address, session, connection): namespace = address.namespace db, coll = namespace.split(".", 1) spec = SON([("killCursors", coll), ("cursors", cursor_ids)]) - sock_info.command(db, spec, session=session, client=self) + connection.command(db, spec, session=session, client=self) def _process_kill_cursors(self): """Process any pending kill cursors requests.""" @@ -1925,9 +1928,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if not isinstance(name, str): raise TypeError("name_or_database must be an instance of str or a Database") - with self._socket_for_writes(session) as sock_info: + with self._socket_for_writes(session) as connection: self[name]._command( - sock_info, + connection, {"dropDatabase": 1, "comment": comment}, read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), @@ -2149,11 +2152,11 @@ class _MongoClientErrorHandler: self.service_id = None self.handled = False - def contribute_socket(self, sock_info, completed_handshake=True): + def contribute_socket(self, connection, completed_handshake=True): """Provide socket information to the error handler.""" - self.max_wire_version = sock_info.max_wire_version - self.sock_generation = sock_info.generation - self.service_id = sock_info.service_id + self.max_wire_version = connection.max_wire_version + self.sock_generation = connection.generation + self.service_id = connection.service_id self.completed_handshake = completed_handshake def handle(self, exc_type, exc_val): diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 2fc0bf8ba..d28b550b4 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -245,9 +245,9 @@ class Monitor(MonitorBase): if self._cancel_context and self._cancel_context.cancelled: self._reset_connection() - with self._pool.get_socket() as sock_info: - self._cancel_context = sock_info.cancel_context - response, round_trip_time = self._check_with_socket(sock_info) + with self._pool.get_socket() as connection: + self._cancel_context = connection.cancel_context + response, round_trip_time = self._check_with_socket(connection) if not response.awaitable: self._rtt_monitor.add_sample(round_trip_time) @@ -393,11 +393,11 @@ class _RttMonitor(MonitorBase): def _ping(self): """Run a "hello" command and return the RTT.""" - with self._pool.get_socket() as sock_info: + with self._pool.get_socket() as connection: if self._executor._stopped: raise Exception("_RttMonitor closed") start = time.monotonic() - sock_info.hello() + connection.hello() return time.monotonic() - start diff --git a/pymongo/network.py b/pymongo/network.py index 4cff1e529..c1a225f3e 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -52,7 +52,7 @@ if TYPE_CHECKING: from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.mongo_client import MongoClient from pymongo.monitoring import _EventListeners - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.typings import _Address @@ -62,7 +62,7 @@ _UNPACK_HEADER = struct.Struct(" Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" if _csot.get_timeout(): deadline = _csot.get_deadline() else: - timeout = sock_info.sock.gettimeout() + timeout = connection.connector.gettimeout() if timeout: deadline = time.monotonic() + timeout else: deadline = None # Ignore the response's request id. length, _, response_to, op_code = _UNPACK_HEADER( - _receive_data_on_socket(sock_info, 16, deadline) + _receive_data_on_socket(connection, 16, deadline) ) # No request_id for exhaust cursor "getMore". if request_id is not None: @@ -260,11 +260,11 @@ def receive_message( ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - _receive_data_on_socket(sock_info, 9, deadline) + _receive_data_on_socket(connection, 9, deadline) ) - data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id) + data = decompress(_receive_data_on_socket(connection, length - 25, deadline), compressor_id) else: - data = _receive_data_on_socket(sock_info, length - 16, deadline) + data = _receive_data_on_socket(connection, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -273,15 +273,51 @@ def receive_message( return unpack_reply(data) +def receive_message_grpc( + connection: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" + # Ignore the response's request id. + if connection.response is not None: + data = connection.response.__next__() + length, _, response_to, op_code = _UNPACK_HEADER(data[:16]) + # No request_id for exhaust cursor "getMore". + if request_id is not None: + if request_id != response_to: + raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") + if length <= 16: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard message header size (16)" + ) + if length > max_message_size: + raise ProtocolError( + "Message length ({!r}) is larger than server max " + "message size ({!r})".format(length, max_message_size) + ) + if op_code == 2012: + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(data[:9]) + data = decompress(data[25:], compressor_id) + else: + data = data[16:] + + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}") + return unpack_reply(data) + else: + raise ProtocolError(f"Received empty response for request id {request_id}") + + _POLL_TIMEOUT = 0.5 -def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None: +def wait_for_read(connection: Connection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" - context = sock_info.cancel_context + context = connection.cancel_context # Only Monitor connections can be cancelled. if context: - sock = sock_info.sock + sock = connection.connector timed_out = False while True: # SSLSocket can have buffered data which won't be caught by select. @@ -300,7 +336,7 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None: timeout = max(min(remaining, _POLL_TIMEOUT), 0) else: timeout = _POLL_TIMEOUT - readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout) + readable = connection.socket_checker.select(sock, read=True, timeout=timeout) if context.cancelled: raise _OperationCancelled("hello cancelled") if readable: @@ -314,20 +350,20 @@ BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS) def _receive_data_on_socket( - sock_info: SocketInfo, length: int, deadline: Optional[float] + connection: Connection, length: int, deadline: Optional[float] ) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 while bytes_read < length: try: - wait_for_read(sock_info, deadline) + wait_for_read(connection, deadline) # CSOT: Update timeout. When the timeout has expired perform one # final non-blocking recv. This helps avoid spurious timeouts when # the response is actually already buffered on the client. if _csot.get_timeout() and deadline is not None: - sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = sock_info.sock.recv_into(mv[bytes_read:]) + connection.set_connector_timeout(max(deadline - time.monotonic(), 0)) + chunk_length = connection.connector.recv_into(mv[bytes_read:]) except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") except OSError as exc: # noqa: B014 diff --git a/pymongo/pool.py b/pymongo/pool.py index a827d10f9..f3c1ccc40 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -23,7 +23,14 @@ import sys import threading import time import weakref -from typing import Any, Dict, NoReturn, Optional +from typing import Any, Dict, Iterator, NoReturn, Optional + +try: + import grpc # type: ignore + + _HAVE_GRPC = True +except ImportError: + _HAVE_GRPC = False import bson from bson import DEFAULT_CODEC_OPTIONS @@ -60,7 +67,7 @@ from pymongo.hello import Hello, HelloCompat from pymongo.helpers import _handle_reauth from pymongo.lock import _create_lock from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason -from pymongo.network import command, receive_message +from pymongo.network import command, receive_message_grpc, receive_message_tcp from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE @@ -370,6 +377,12 @@ def _cond_wait(condition, deadline): return condition.wait(timeout) +class ConnectionProtocol: + TCP_SOCKET = 0 + + GRPC = 1 + + class PoolOptions: """Read only connection pool options for a MongoClient. @@ -402,6 +415,7 @@ class PoolOptions: "__server_api", "__load_balanced", "__credentials", + "__protocol", ) def __init__( @@ -423,6 +437,7 @@ class PoolOptions: server_api=None, load_balanced=None, credentials=None, + protocol=ConnectionProtocol.TCP_SOCKET, ): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size @@ -441,6 +456,7 @@ class PoolOptions: self.__server_api = server_api self.__load_balanced = load_balanced self.__credentials = credentials + self.__protocol = protocol self.__metadata = copy.deepcopy(_METADATA) if appname: self.__metadata["application"] = {"name": appname} @@ -473,6 +489,11 @@ class PoolOptions: _truncate_metadata(self.__metadata) + @property + def protocol(self): + """A :class:`~pymongo.pool.ConnectionProtocol` value.""" + return self.__protocol + @property def _credentials(self): """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" @@ -614,21 +635,22 @@ class _CancellationContext: return self._cancelled -class SocketInfo: - """Store a socket with some metadata. +class Connection: + """Store a connection with some metadata. :Parameters: - - `sock`: a raw socket object + - `connector`: a raw connector implementation, currently either a TCP Socket or a gRPC Stream - `pool`: a Pool instance - `address`: the server's (host, port) - - `id`: the id of this socket in it's pool + - `id`: the id of this socket in its pool """ - def __init__(self, sock, pool, address, id): + def __init__(self, connector, pool, address, id, protocol): self.pool_ref = weakref.ref(pool) - self.sock = sock + self.connector = connector self.address = address self.id = id + self.protocol = protocol self.authed = set() self.closed = False self.last_checkin_time = time.monotonic() @@ -671,14 +693,18 @@ class SocketInfo: self.pinned_cursor = False self.active = False self.last_timeout = self.opts.socket_timeout + self.current_timeout = self.last_timeout self.connect_rtt = 0.0 + self.response: Optional[Iterator] = None - def set_socket_timeout(self, timeout): - """Cache last timeout to avoid duplicate calls to sock.settimeout.""" + def set_connector_timeout(self, timeout): + """Cache last timeout to avoid duplicate calls to connector timeout implementations.""" if timeout == self.last_timeout: return self.last_timeout = timeout - self.sock.settimeout(timeout) + self.current_timeout = timeout + if self.protocol == ConnectionProtocol.TCP_SOCKET: # TODO: Implement timeouts for gRPC mode + self.connector.settimeout(timeout) def apply_timeout(self, client, cmd): # CSOT: use remaining timeout when set. @@ -686,7 +712,7 @@ class SocketInfo: if timeout is None: # Reset the socket timeout unless we're performing a streaming monitor check. if not self.more_to_come: - self.set_socket_timeout(self.opts.socket_timeout) + self.set_connector_timeout(self.opts.socket_timeout) return None # RTT validation. rtt = _csot.get_rtt() @@ -701,7 +727,7 @@ class SocketInfo: ) if cmd is not None: cmd["maxTimeMS"] = int(max_time_ms * 1000) - self.set_socket_timeout(timeout) + self.set_connector_timeout(timeout) return timeout def pin_txn(self): @@ -748,7 +774,7 @@ class SocketInfo: awaitable = True # If connect_timeout is None there is no timeout. if self.opts.connect_timeout: - self.set_socket_timeout(self.opts.connect_timeout + heartbeat_frequency) + self.set_connector_timeout(self.opts.connect_timeout + heartbeat_frequency) if not performing_handshake and cluster_time is not None: cmd["$clusterTime"] = cluster_time @@ -910,7 +936,7 @@ class SocketInfo: def send_message(self, message, max_doc_size): """Send a raw BSON message or raise ConnectionFailure. - If a network exception is raised, the socket is closed. + If a network exception is raised, the connector is closed. """ if self.max_bson_size is not None and max_doc_size > self.max_bson_size: raise DocumentTooLarge( @@ -919,17 +945,32 @@ class SocketInfo: ) try: - self.sock.sendall(message) + if self.protocol == ConnectionProtocol.GRPC: + self.response = self.connector.__call__( + iter([message]), + metadata=[ + ("security-uuid", "uuid"), + ("username", "user"), + ("servername", "host.local.10gen.cc"), + ("mongodb-wireversion", "18"), + ("x-forwarded-for", "127.0.0.1:9901"), + ], + ) + else: + self.connector.sendall(message) except BaseException as error: self._raise_connection_failure(error) def receive_message(self, request_id): """Receive a raw BSON message or raise ConnectionFailure. - If any exception is raised, the socket is closed. + If any exception is raised, the connector is closed. """ try: - return receive_message(self, request_id, self.max_message_size) + if self.protocol == ConnectionProtocol.GRPC: + return receive_message_grpc(self, request_id, self.max_message_size) + else: + return receive_message_tcp(self, request_id, self.max_message_size) except BaseException as error: self._raise_connection_failure(error) @@ -1017,13 +1058,13 @@ class SocketInfo: # Note: We catch exceptions to avoid spurious errors on interpreter # shutdown. try: - self.sock.close() + self.connector.close() except Exception: pass def socket_closed(self): """Return True if we know socket has been closed, False otherwise.""" - return self.socket_checker.socket_closed(self.sock) + return self.socket_checker.socket_closed(self.connector) def send_cluster_time(self, command, session, client): """Add $clusterTime.""" @@ -1073,17 +1114,17 @@ class SocketInfo: raise def __eq__(self, other): - return self.sock == other.sock + return self.connector == other.connector def __ne__(self, other): return not self == other def __hash__(self): - return hash(self.sock) + return hash(self.connector) def __repr__(self): - return "SocketInfo({}){} at {}".format( - repr(self.sock), + return "Connection({}){} at {}".format( + repr(self.connector), self.closed and " CLOSED" or "", id(self), ) @@ -1256,7 +1297,7 @@ class Pool: :Parameters: - `address`: a (hostname, port) tuple - `options`: a PoolOptions instance - - `handshake`: whether to call hello for each new SocketInfo + - `handshake`: whether to call hello for each new Connection """ if options.pause_enabled: self.state = PoolState.PAUSED @@ -1318,6 +1359,11 @@ class Pool: self.ncursors = 0 self.ntxns = 0 + if self.opts.protocol == ConnectionProtocol.GRPC: + self.channel = self._create_grpc_channel() + else: + self.channel = None + def ready(self): # Take the lock to avoid the race condition described in PYTHON-2699. with self.lock: @@ -1348,11 +1394,11 @@ class Pool: else: discard: collections.deque = collections.deque() keep: collections.deque = collections.deque() - for sock_info in self.sockets: - if sock_info.service_id == service_id: - discard.append(sock_info) + for connection in self.sockets: + if connection.service_id == service_id: + discard.append(connection) else: - keep.append(sock_info) + keep.append(connection) sockets = discard self.sockets = keep @@ -1367,15 +1413,15 @@ class Pool: # PoolClosedEvent but that reset() SHOULD close sockets *after* # publishing the PoolClearedEvent. if close: - for sock_info in sockets: - sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) + for connection in sockets: + connection.close_socket(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: listeners.publish_pool_closed(self.address) else: if old_state != PoolState.PAUSED and self.enabled_for_cmap: listeners.publish_pool_cleared(self.address, service_id=service_id) - for sock_info in sockets: - sock_info.close_socket(ConnectionClosedReason.STALE) + for connection in sockets: + connection.close_socket(ConnectionClosedReason.STALE) def update_is_writable(self, is_writable): """Updates the is_writable attribute on all sockets currently in the @@ -1395,6 +1441,10 @@ class Pool: def close(self): self._reset(close=True) + def _create_grpc_channel(self): + connection_string = self.address[0] + ":" + str(self.address[1]) + return grpc.insecure_channel(connection_string, options=self.grpc_channel_options()) + def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) @@ -1415,8 +1465,8 @@ class Pool: self.sockets and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): - sock_info = self.sockets.pop() - sock_info.close_socket(ConnectionClosedReason.IDLE) + connection = self.sockets.pop() + connection.close_socket(ConnectionClosedReason.IDLE) while True: with self.size_cond: @@ -1435,14 +1485,14 @@ class Pool: return self._pending += 1 incremented = True - sock_info = self.connect() + connection = self.connect() with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. if self.gen.get_overall() != reference_generation: - sock_info.close_socket(ConnectionClosedReason.STALE) + connection.close_socket(ConnectionClosedReason.STALE) return - self.sockets.appendleft(sock_info) + self.sockets.appendleft(connection) finally: if incremented: # Notify after adding the socket to the pool. @@ -1455,7 +1505,7 @@ class Pool: self.size_cond.notify() def connect(self, handler=None): - """Connect to Mongo and return a new SocketInfo. + """Connect to Mongo and return a new Connection. Can raise ConnectionFailure. @@ -1471,7 +1521,12 @@ class Pool: listeners.publish_connection_created(self.address, conn_id) try: - sock = _configured_socket(self.address, self.opts) + if self.opts.protocol == ConnectionProtocol.GRPC: + connector = self.channel.stream_stream( + "/mongodb.CommandService/UnauthenticatedCommandStream" + ) + else: + connector = _configured_socket(self.address, self.opts) except BaseException as error: if self.enabled_for_cmap: listeners.publish_connection_closed( @@ -1483,26 +1538,42 @@ class Pool: raise - sock_info = SocketInfo(sock, self, self.address, conn_id) + connection = Connection(connector, self, self.address, conn_id, self.opts.protocol) try: if self.handshake: - sock_info.hello() - self.is_writable = sock_info.is_writable + connection.hello() + self.is_writable = connection.is_writable if handler: - handler.contribute_socket(sock_info, completed_handshake=False) + handler.contribute_socket(connection, completed_handshake=False) - sock_info.authenticate() + connection.authenticate() except BaseException: - sock_info.close_socket(ConnectionClosedReason.ERROR) + connection.close_socket(ConnectionClosedReason.ERROR) raise - return sock_info + return connection + + def grpc_metadata(self): + return [ + ("security-uuid", "client-uuid"), + ("username", "user"), + ("servername", "host.local.10gen.cc"), + ("mongodb-wireversion", 18), + ("x-forwarded-for", "127.0.0.1:9901"), + ] + + def grpc_channel_options(self): + return [ + ("grpc.default_authority", "host.local.10gen.cc"), + ("grpc.max_receive_message_length", 48000000), + ("grpc.max_send_message_length", 48000000), + ] @contextlib.contextmanager def get_socket(self, handler=None): """Get a socket from the pool. Use with a "with" statement. - Returns a :class:`SocketInfo` object wrapping a connected + Returns a :class:`Connection` object wrapping a connected :class:`socket.socket`. This method should always be used in a with-statement:: @@ -1520,36 +1591,36 @@ class Pool: if self.enabled_for_cmap: listeners.publish_connection_check_out_started(self.address) - sock_info = self._get_socket(handler=handler) + connection = self._get_socket(handler=handler) if self.enabled_for_cmap: - listeners.publish_connection_checked_out(self.address, sock_info.id) + listeners.publish_connection_checked_out(self.address, connection.id) try: - yield sock_info + yield connection except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the # connection and it is responsible for checking the connection # back into the pool. - pinned = sock_info.pinned_txn or sock_info.pinned_cursor + pinned = connection.pinned_txn or connection.pinned_cursor if handler: # Perform SDAM error handling rules while the connection is # still checked out. exc_type, exc_val, _ = sys.exc_info() handler.handle(exc_type, exc_val) - if not pinned and sock_info.active: - self.return_socket(sock_info) + if not pinned and connection.active: + self.return_socket(connection) raise - if sock_info.pinned_txn: + if connection.pinned_txn: with self.lock: - self.__pinned_sockets.add(sock_info) + self.__pinned_sockets.add(connection) self.ntxns += 1 - elif sock_info.pinned_cursor: + elif connection.pinned_cursor: with self.lock: - self.__pinned_sockets.add(sock_info) + self.__pinned_sockets.add(connection) self.ncursors += 1 - elif sock_info.active: - self.return_socket(sock_info) + elif connection.active: + self.return_socket(connection) def _raise_if_not_ready(self, emit_event): if self.state != PoolState.READY: @@ -1560,7 +1631,7 @@ class Pool: _raise_connection_failure(self.address, AutoReconnect("connection pool paused")) def _get_socket(self, handler=None): - """Get or create a SocketInfo. Can raise ConnectionFailure.""" + """Get or create a Connection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise @@ -1600,7 +1671,7 @@ class Pool: self.requests += 1 # We've now acquired the semaphore and must release it on error. - sock_info = None + connection = None incremented = False emitted_event = False try: @@ -1608,7 +1679,7 @@ class Pool: self.active_sockets += 1 incremented = True - while sock_info is None: + while connection is None: # CMAP: we MUST wait for either maxConnecting OR for a socket # to be checked back into the pool. with self._max_connecting_cond: @@ -1624,24 +1695,24 @@ class Pool: self._raise_if_not_ready(emit_event=False) try: - sock_info = self.sockets.popleft() + connection = self.sockets.popleft() except IndexError: self._pending += 1 - if sock_info: # We got a socket from the pool - if self._perished(sock_info): - sock_info = None + if connection: # We got a socket from the pool + if self._perished(connection): + connection = None continue else: # We need to create a new connection try: - sock_info = self.connect(handler=handler) + connection = self.connect(handler=handler) finally: with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() except BaseException: - if sock_info: + if connection: # We checked out a socket but authentication failed. - sock_info.close_socket(ConnectionClosedReason.ERROR) + connection.close_socket(ConnectionClosedReason.ERROR) with self.size_cond: self.requests -= 1 if incremented: @@ -1654,45 +1725,45 @@ class Pool: ) raise - sock_info.active = True - return sock_info + connection.active = True + return connection - def return_socket(self, sock_info): + def return_socket(self, connection): """Return the socket to the pool, or if it's closed discard it. :Parameters: - - `sock_info`: The socket to check into the pool. + - `connection`: The socket to check into the pool. """ - txn = sock_info.pinned_txn - cursor = sock_info.pinned_cursor - sock_info.active = False - sock_info.pinned_txn = False - sock_info.pinned_cursor = False - self.__pinned_sockets.discard(sock_info) + txn = connection.pinned_txn + cursor = connection.pinned_cursor + connection.active = False + connection.pinned_txn = False + connection.pinned_cursor = False + self.__pinned_sockets.discard(connection) listeners = self.opts._event_listeners if self.enabled_for_cmap: - listeners.publish_connection_checked_in(self.address, sock_info.id) + listeners.publish_connection_checked_in(self.address, connection.id) if self.pid != os.getpid(): self.reset_without_pause() else: if self.closed: - sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) - elif sock_info.closed: + connection.close_socket(ConnectionClosedReason.POOL_CLOSED) + elif connection.closed: # CMAP requires the closed event be emitted after the check in. if self.enabled_for_cmap: listeners.publish_connection_closed( - self.address, sock_info.id, ConnectionClosedReason.ERROR + self.address, connection.id, ConnectionClosedReason.ERROR ) else: with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). - if self.stale_generation(sock_info.generation, sock_info.service_id): - sock_info.close_socket(ConnectionClosedReason.STALE) + if self.stale_generation(connection.generation, connection.service_id): + connection.close_socket(ConnectionClosedReason.STALE) else: - sock_info.update_last_checkin_time() - sock_info.update_is_writable(self.is_writable) - self.sockets.appendleft(sock_info) + connection.update_last_checkin_time() + connection.update_is_writable(self.is_writable) + self.sockets.appendleft(connection) # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() @@ -1706,7 +1777,7 @@ class Pool: self.operation_count -= 1 self.size_cond.notify() - def _perished(self, sock_info): + def _perished(self, connection): """Return True and close the connection if it is "perished". This side-effecty function checks if this socket has been idle for @@ -1720,24 +1791,24 @@ class Pool: pool, to keep performance reasonable - we can't avoid AutoReconnects completely anyway. """ - idle_time_seconds = sock_info.idle_time_seconds() + idle_time_seconds = connection.idle_time_seconds() # If socket is idle, open a new one. if ( self.opts.max_idle_time_seconds is not None and idle_time_seconds > self.opts.max_idle_time_seconds ): - sock_info.close_socket(ConnectionClosedReason.IDLE) + connection.close_socket(ConnectionClosedReason.IDLE) return True if self._check_interval_seconds is not None and ( 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds ): - if sock_info.socket_closed(): - sock_info.close_socket(ConnectionClosedReason.ERROR) + if connection.socket_closed(): + connection.close_socket(ConnectionClosedReason.ERROR) return True - if self.stale_generation(sock_info.generation, sock_info.service_id): - sock_info.close_socket(ConnectionClosedReason.STALE) + if self.stale_generation(connection.generation, connection.service_id): + connection.close_socket(ConnectionClosedReason.STALE) return True return False @@ -1772,5 +1843,5 @@ class Pool: # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. - for sock_info in self.sockets: - sock_info.close_socket(None) + for connection in self.sockets: + connection.close_socket(None) diff --git a/pymongo/response.py b/pymongo/response.py index bd4795bfb..004e29c8b 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from datetime import timedelta from pymongo.message import _OpMsg, _OpReply - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.typings import _Address @@ -91,7 +91,7 @@ class PinnedResponse(Response): self, data: Union[_OpMsg, _OpReply], address: _Address, - socket_info: SocketInfo, + socket_info: Connection, request_id: int, duration: Optional[timedelta], from_command: bool, @@ -103,7 +103,7 @@ class PinnedResponse(Response): :Parameters: - `data`: A network response message. - `address`: (host, port) of the source server. - - `socket_info`: The SocketInfo used for the initial query. + - `socket_info`: The Connection used for the initial query. - `request_id`: The request id of this operation. - `duration`: The duration of the operation. - `from_command`: If the response is the result of a db command. @@ -116,8 +116,8 @@ class PinnedResponse(Response): self._more_to_come = more_to_come @property - def socket_info(self) -> SocketInfo: - """The SocketInfo used for the initial query. + def socket_info(self) -> Connection: + """The Connection used for the initial query. The server will send batches on this socket, without waiting for getMores from the client, until the result set is exhausted or there diff --git a/pymongo/server.py b/pymongo/server.py index 349af4a41..453d74889 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: from pymongo.mongo_client import _MongoClientErrorHandler from pymongo.monitor import Monitor from pymongo.monitoring import _EventListeners - from pymongo.pool import Pool, SocketInfo + from pymongo.pool import Connection, Pool from pymongo.server_description import ServerDescription _CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} @@ -105,7 +105,7 @@ class Server: @_handle_reauth def run_operation( self, - sock_info: SocketInfo, + connection: Connection, operation: Union[_Query, _GetMore], read_preference: bool, listeners: _EventListeners, @@ -118,7 +118,7 @@ class Server: Can raise ConnectionFailure, OperationFailure, etc. :Parameters: - - `sock_info`: A SocketInfo instance. + - `connection`: A Connection instance. - `operation`: A _Query or _GetMore object. - `read_preference`: The read preference to use. - `listeners`: Instance of _EventListeners or None. @@ -129,27 +129,27 @@ class Server: if publish: start = datetime.now() - use_cmd = operation.use_command(sock_info) + use_cmd = operation.use_command(connection) more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come if more_to_come: request_id = 0 else: - message = operation.get_message(read_preference, sock_info, use_cmd) + message = operation.get_message(read_preference, connection, use_cmd) request_id, data, max_doc_size = self._split_message(message) if publish: - cmd, dbn = operation.as_command(sock_info) + cmd, dbn = operation.as_command(connection) listeners.publish_command_start( - cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id + cmd, dbn, request_id, connection.address, service_id=connection.service_id ) start = datetime.now() try: if more_to_come: - reply = sock_info.receive_message(None) + reply = connection.receive_message(None) else: - sock_info.send_message(data, max_doc_size) - reply = sock_info.receive_message(request_id) + connection.send_message(data, max_doc_size) + reply = connection.receive_message(request_id) # Unpack and check for command errors. if use_cmd: @@ -168,7 +168,7 @@ class Server: if use_cmd: first = docs[0] operation.client._process_response(first, operation.session) - _check_command_response(first, sock_info.max_wire_version) + _check_command_response(first, connection.max_wire_version) except Exception as exc: if publish: duration = datetime.now() - start @@ -181,8 +181,8 @@ class Server: failure, operation.name, request_id, - sock_info.address, - service_id=sock_info.service_id, + connection.address, + service_id=connection.service_id, ) raise @@ -205,8 +205,8 @@ class Server: res, operation.name, request_id, - sock_info.address, - service_id=sock_info.service_id, + connection.address, + service_id=connection.service_id, ) # Decrypt response. @@ -219,7 +219,7 @@ class Server: response: Response if client._should_pin_cursor(operation.session) or operation.exhaust: - sock_info.pin_cursor() + connection.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the # more_to_come flag is set. @@ -232,7 +232,7 @@ class Server: response = PinnedResponse( data=reply, address=self._description.address, - socket_info=sock_info, + socket_info=connection, duration=duration, request_id=request_id, from_command=use_cmd, @@ -253,7 +253,7 @@ class Server: def get_socket( self, handler: Optional[_MongoClientErrorHandler] = None - ) -> ContextManager[SocketInfo]: + ) -> ContextManager[Connection]: return self.pool.get_socket(handler) @property diff --git a/pymongo/topology.py b/pymongo/topology.py index 0a2eaf942..1c3eaa58a 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -778,6 +778,7 @@ class Topology: driver=options.driver, pause_enabled=False, server_api=options.server_api, + protocol=options.protocol, ) return self._settings.pool_class(address, monitor_pool_options, handshake=False) diff --git a/pyproject.toml b/pyproject.toml index 85b322af0..92e23ad82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ classifiers = [ ] dependencies = [ "dnspython>=1.16.0,<3.0.0", + "grpcio>=1.56.0" ] [project.optional-dependencies] diff --git a/test/grpc_example.py b/test/grpc_example.py new file mode 100644 index 000000000..09699ddf5 --- /dev/null +++ b/test/grpc_example.py @@ -0,0 +1,7 @@ +from pymongo import MongoClient + +client = MongoClient("mongodb://host9.local.10gen.cc:9901", grpc=True, loadBalanced=True) + +dbs = client.list_databases() +for db in dbs: + print(db["name"], ": ", client[db["name"]].list_collection_names()) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 2e7fda21e..4ff477e86 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -48,10 +48,10 @@ class MockPool(Pool): client.mock_standalones + client.mock_members + client.mock_mongoses ), ("bad host: %s" % host_and_port) - with Pool.get_socket(self, handler) as sock_info: - sock_info.mock_host = self.mock_host - sock_info.mock_port = self.mock_port - yield sock_info + with Pool.get_socket(self, handler) as connection: + connection.mock_host = self.mock_host + connection.mock_port = self.mock_port + yield connection class DummyMonitor: diff --git a/test/test_client.py b/test/test_client.py index bba6b3728..178061d28 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -96,7 +96,7 @@ from pymongo.errors import ( ) from pymongo.mongo_client import MongoClient from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent -from pymongo.pool import _METADATA, PoolOptions, SocketInfo +from pymongo.pool import _METADATA, Connection, PoolOptions from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription from pymongo.server_selectors import readable_server_selector, writable_server_selector @@ -541,10 +541,10 @@ class TestClient(IntegrationTest): # Assert reaper doesn't remove sockets when maxIdleTimeMS not set client = rs_or_single_client() server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.get_socket() as connection: pass self.assertEqual(1, len(server._pool.sockets)) - self.assertTrue(sock_info in server._pool.sockets) + self.assertTrue(connection in server._pool.sockets) client.close() def test_max_idle_time_reaper_removes_stale_minPoolSize(self): @@ -552,12 +552,12 @@ class TestClient(IntegrationTest): # Assert reaper removes idle socket and replaces it with a new one client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.get_socket() as connection: pass # When the reaper runs at the same time as the get_socket, two # sockets could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.sockets), 1) - wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") + wait_until(lambda: connection not in server._pool.sockets, "remove stale socket") wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket") client.close() @@ -566,12 +566,12 @@ class TestClient(IntegrationTest): # Assert reaper respects maxPoolSize when adding new sockets. client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.get_socket() as connection: pass # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two sockets from being created. self.assertEqual(1, len(server._pool.sockets)) - wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") + wait_until(lambda: connection not in server._pool.sockets, "remove stale socket") wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket") client.close() @@ -605,39 +605,39 @@ class TestClient(IntegrationTest): wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") # Assert that if a socket is closed, a new one takes its place - with server._pool.get_socket() as sock_info: - sock_info.close_socket(None) + with server._pool.get_socket() as connection: + connection.close_socket(None) wait_until( lambda: 10 == len(server._pool.sockets), "a closed socket gets replaced from the pool", ) - self.assertFalse(sock_info in server._pool.sockets) + self.assertFalse(connection in server._pool.sockets) def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.get_socket() as connection: pass self.assertEqual(1, len(server._pool.sockets)) time.sleep(1) # Sleep so that the socket becomes stale. with server._pool.get_socket() as new_sock_info: - self.assertNotEqual(sock_info, new_sock_info) + self.assertNotEqual(connection, new_sock_info) self.assertEqual(1, len(server._pool.sockets)) - self.assertFalse(sock_info in server._pool.sockets) + self.assertFalse(connection in server._pool.sockets) self.assertTrue(new_sock_info in server._pool.sockets) # Test that sockets are reused if maxIdleTimeMS is not set. client = rs_or_single_client() server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.get_socket() as connection: pass self.assertEqual(1, len(server._pool.sockets)) time.sleep(1) with server._pool.get_socket() as new_sock_info: - self.assertEqual(sock_info, new_sock_info) + self.assertEqual(connection, new_sock_info) self.assertEqual(1, len(server._pool.sockets)) def test_constants(self): @@ -1130,8 +1130,8 @@ class TestClient(IntegrationTest): def test_socketKeepAlive(self): pool = get_pool(self.client) - with pool.get_socket() as sock_info: - keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + with pool.get_socket() as connection: + keepalive = connection.connector.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) @no_type_check @@ -1326,13 +1326,13 @@ class TestClient(IntegrationTest): connected(client) # Cause a network error. - sock_info = one(pool.sockets) - sock_info.sock.close() + connection = one(pool.sockets) + connection.connector.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) with self.assertRaises(ConnectionFailure): next(cursor) - self.assertTrue(sock_info.closed) + self.assertTrue(connection.closed) # The semaphore was decremented despite the error. self.assertEqual(0, pool.requests) @@ -1350,7 +1350,7 @@ class TestClient(IntegrationTest): socket_info = one(pool.sockets) socket_info.sock.close() - # SocketInfo.authenticate logs, but gets a socket.error. Should be + # Connection.authenticate logs, but gets a socket.error. Should be # reraised as AutoReconnect. self.assertRaises(AutoReconnect, c.test.collection.find_one) @@ -1847,7 +1847,7 @@ class TestExhaustCursor(IntegrationTest): collection = client.pymongo_test.test pool = get_pool(client) - sock_info = one(pool.sockets) + connection = one(pool.sockets) # This will cause OperationFailure in all mongo versions since # the value for $orderby must be a document. @@ -1856,10 +1856,10 @@ class TestExhaustCursor(IntegrationTest): ) self.assertRaises(OperationFailure, cursor.next) - self.assertFalse(sock_info.closed) + self.assertFalse(connection.closed) # The socket was checked in and the semaphore was decremented. - self.assertIn(sock_info, pool.sockets) + self.assertIn(connection, pool.sockets) self.assertEqual(0, pool.requests) def test_exhaust_getmore_server_error(self): @@ -1874,7 +1874,7 @@ class TestExhaustCursor(IntegrationTest): pool = get_pool(client) pool._check_interval_seconds = None # Never check. - sock_info = one(pool.sockets) + connection = one(pool.sockets) cursor = collection.find(cursor_type=CursorType.EXHAUST) @@ -1884,21 +1884,21 @@ class TestExhaustCursor(IntegrationTest): # Cause a server error on getmore. def receive_message(request_id): # Discard the actual server response. - SocketInfo.receive_message(sock_info, request_id) + Connection.receive_message(connection, request_id) # responseFlags bit 1 is QueryFailure. msg = struct.pack("