PYTHON-3879 Rename SocketInfo to Connection (#1329)

This commit is contained in:
Noah Stapp 2023-07-28 10:04:16 -07:00 committed by GitHub
parent c945ec6302
commit c88ae79e58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 731 additions and 735 deletions

View File

@ -78,7 +78,7 @@ will receive the following error::
File "/Library/Python/2.7/site-packages/pymongo/collection.py", line 1560, in count
return self._count(cmd, collation, session)
File "/Library/Python/2.7/site-packages/pymongo/collection.py", line 1504, in _count
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads() as (connection, slave_ok):
File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/contextlib.py", line 17, in __enter__
return self.gen.next()
File "/Library/Python/2.7/site-packages/pymongo/mongo_client.py", line 982, in _socket_for_reads

View File

@ -29,7 +29,7 @@ if TYPE_CHECKING:
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from pymongo.database import Database
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.read_preferences import _ServerMode
from pymongo.server import Server
from pymongo.typings import _Pipeline
@ -52,7 +52,7 @@ class _AggregationCommand:
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], SocketInfo], None]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
comment: Any = None,
) -> None:
if "explain" in options:
@ -134,7 +134,7 @@ class _AggregationCommand:
self,
session: ClientSession,
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor:
# Serialize command.
@ -146,7 +146,7 @@ class _AggregationCommand:
# - server version is >= 4.2 or
# - server version is >= 3.2 and pipeline doesn't use $out
if ("readConcern" not in cmd) and (
not self._performs_write or (sock_info.max_wire_version >= 8)
not self._performs_write or (conn.max_wire_version >= 8)
):
read_concern = self._target.read_concern
else:
@ -161,7 +161,7 @@ class _AggregationCommand:
write_concern = None
# Run command.
result = sock_info.command(
result = conn.command(
self._database.name,
cmd,
read_preference,
@ -176,7 +176,7 @@ class _AggregationCommand:
)
if self._result_processor:
self._result_processor(result, sock_info)
self._result_processor(result, conn)
# Extract cursor from result or mock/fake one if necessary.
if "cursor" in result:
@ -193,14 +193,14 @@ class _AggregationCommand:
cmd_cursor = self._cursor_class(
self._cursor_collection(cursor),
cursor,
sock_info.address,
conn.address,
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
cmd_cursor._maybe_pin_connection(sock_info)
cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor

View File

@ -35,7 +35,7 @@ from pymongo.saslprep import saslprep
if TYPE_CHECKING:
from pymongo.hello import Hello
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
@ -220,9 +220,7 @@ def _authenticate_scram_start(
return nonce, first_bare, cmd
def _authenticate_scram(
credentials: MongoCredential, sock_info: SocketInfo, mechanism: str
) -> None:
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == "SCRAM-SHA-256":
@ -239,13 +237,13 @@ def _authenticate_scram(
# Make local
_hmac = hmac.HMAC
ctx = sock_info.auth_ctx
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = sock_info.command(source, cmd)
res = conn.command(source, cmd)
server_first = res["payload"]
parsed = _parse_scram_response(server_first)
@ -285,7 +283,7 @@ def _authenticate_scram(
("payload", Binary(client_final)),
]
)
res = sock_info.command(source, cmd)
res = conn.command(source, cmd)
parsed = _parse_scram_response(res["payload"])
if not hmac.compare_digest(parsed[b"v"], server_sig):
@ -301,7 +299,7 @@ def _authenticate_scram(
("payload", Binary(b"")),
]
)
res = sock_info.command(source, cmd)
res = conn.command(source, cmd)
if not res["done"]:
raise OperationFailure("SASL conversation failed to complete.")
@ -345,7 +343,7 @@ def _canonicalize_hostname(hostname: str) -> str:
return name[0].lower()
def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None:
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
@ -358,7 +356,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
props = credentials.mechanism_properties
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = sock_info.address[0]
host = conn.address[0]
if props.canonicalize_host_name:
host = _canonicalize_hostname(host)
service = props.service_name + "@" + host
@ -413,7 +411,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
("autoAuthorize", 1),
]
)
response = sock_info.command("$external", cmd)
response = conn.command("$external", cmd)
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
@ -430,7 +428,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
("payload", payload),
]
)
response = sock_info.command("$external", cmd)
response = conn.command("$external", cmd)
if result == kerberos.AUTH_GSS_COMPLETE:
break
@ -453,7 +451,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
("payload", payload),
]
)
sock_info.command("$external", cmd)
conn.command("$external", cmd)
finally:
kerberos.authGSSClientClean(ctx)
@ -462,7 +460,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
raise OperationFailure(str(exc))
def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None:
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
@ -476,52 +474,50 @@ def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) ->
("autoAuthorize", 1),
]
)
sock_info.command(source, cmd)
conn.command(source, cmd)
def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None:
def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-X509."""
ctx = sock_info.auth_ctx
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step.
return
cmd = _X509Context(credentials, sock_info.address).speculate_command()
sock_info.command("$external", cmd)
cmd = _X509Context(credentials, conn.address).speculate_command()
conn.command("$external", cmd)
def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None:
def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
password = credentials.password
# Get a nonce
response = sock_info.command(source, {"getnonce": 1})
response = conn.command(source, {"getnonce": 1})
nonce = response["nonce"]
key = _auth_key(nonce, username, password)
# Actually authenticate
query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)])
sock_info.command(source, query)
conn.command(source, query)
def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None:
if sock_info.max_wire_version >= 7:
if sock_info.negotiated_mechs:
mechs = sock_info.negotiated_mechs
def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
if conn.max_wire_version >= 7:
if conn.negotiated_mechs:
mechs = conn.negotiated_mechs
else:
source = credentials.source
cmd = sock_info.hello_cmd()
cmd = conn.hello_cmd()
cmd["saslSupportedMechs"] = source + "." + credentials.username
mechs = sock_info.command(source, cmd, publish_events=False).get(
"saslSupportedMechs", []
)
mechs = conn.command(source, cmd, publish_events=False).get("saslSupportedMechs", [])
if "SCRAM-SHA-256" in mechs:
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256")
return _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
else:
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1")
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
else:
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1")
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
_AUTH_MAP: Mapping[str, Callable] = {
@ -606,12 +602,12 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
def authenticate(
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
) -> None:
"""Authenticate sock_info."""
"""Authenticate connection."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]
if mechanism == "MONGODB-OIDC":
_authenticate_oidc(credentials, sock_info, reauthenticate)
_authenticate_oidc(credentials, conn, reauthenticate)
else:
auth_func(credentials, sock_info)
auth_func(credentials, conn)

View File

@ -49,7 +49,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING:
from bson.typings import _ReadableBuffer
from pymongo.auth import MongoCredential
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
class _AwsSaslContext(AwsSaslContext): # type: ignore
@ -67,7 +67,7 @@ class _AwsSaslContext(AwsSaslContext): # type: ignore
return bson.decode(data)
def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> None:
def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-AWS."""
if not _HAVE_MONGODB_AWS:
raise ConfigurationError(
@ -75,7 +75,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
"install with: python -m pip install 'pymongo[aws]'"
)
if sock_info.max_wire_version < 9:
if conn.max_wire_version < 9:
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
try:
@ -90,7 +90,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
client_first = SON(
[("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)]
)
server_first = sock_info.command("$external", client_first)
server_first = conn.command("$external", client_first)
res = server_first
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
@ -102,7 +102,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
("payload", client_payload),
]
)
res = sock_info.command("$external", cmd)
res = conn.command("$external", cmd)
if res["done"]:
# SASL complete.
break

View File

@ -29,7 +29,7 @@ from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
if TYPE_CHECKING:
from pymongo.auth import MongoCredential
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
@dataclass
@ -242,25 +242,23 @@ class _OIDCAuthenticator:
self.idp_resp = None
self.token_exp_utc = None
def run_command(
self, sock_info: SocketInfo, cmd: Mapping[str, Any]
) -> Optional[Mapping[str, Any]]:
def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
try:
return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as exc:
self.clear()
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
if "jwt" in bson.decode(cmd["payload"]):
if self.idp_info_gen_id > self.reauth_gen_id:
raise
return self.authenticate(sock_info, reauthenticate=True)
return self.authenticate(conn, reauthenticate=True)
raise
def authenticate(
self, sock_info: SocketInfo, reauthenticate: bool = False
self, conn: Connection, reauthenticate: bool = False
) -> Optional[Mapping[str, Any]]:
if reauthenticate:
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
prev_id = getattr(conn, "oidc_token_gen_id", None)
# Check if we've already changed tokens.
if prev_id == self.token_gen_id:
self.reauth_gen_id = self.idp_info_gen_id
@ -268,7 +266,7 @@ class _OIDCAuthenticator:
if not self.properties.refresh_token_callback:
self.clear()
ctx = sock_info.auth_ctx
ctx = conn.auth_ctx
cmd = None
if ctx and ctx.speculate_succeeded():
@ -276,10 +274,10 @@ class _OIDCAuthenticator:
else:
cmd = self.auth_start_cmd()
assert cmd is not None
resp = self.run_command(sock_info, cmd)
resp = self.run_command(conn, cmd)
if resp["done"]:
sock_info.oidc_token_gen_id = self.token_gen_id
conn.oidc_token_gen_id = self.token_gen_id
return None
server_resp: Dict = bson.decode(resp["payload"])
@ -289,7 +287,7 @@ class _OIDCAuthenticator:
conversation_id = resp["conversationId"]
token = self.get_current_token()
sock_info.oidc_token_gen_id = self.token_gen_id
conn.oidc_token_gen_id = self.token_gen_id
bin_payload = Binary(bson.encode({"jwt": token}))
cmd = SON(
[
@ -298,7 +296,7 @@ class _OIDCAuthenticator:
("payload", bin_payload),
]
)
resp = self.run_command(sock_info, cmd)
resp = self.run_command(conn, cmd)
if not resp["done"]:
self.clear()
raise OperationFailure("SASL conversation failed to complete.")
@ -306,8 +304,8 @@ class _OIDCAuthenticator:
def _authenticate_oidc(
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool
credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, sock_info.address)
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)
authenticator = _get_authenticator(credentials, conn.address)
return authenticator.authenticate(conn, reauthenticate=reauthenticate)

View File

@ -65,7 +65,7 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongo.collection import Collection
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
_DELETE_ALL: int = 0
@ -311,7 +311,7 @@ class _Bulk:
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[ClientSession],
sock_info: SocketInfo,
conn: Connection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
@ -326,9 +326,9 @@ class _Bulk:
self.next_run = None
run = self.current_run
# sock_info.command validates the session, but we use
# sock_info.write_command.
sock_info.validate_session(client, session)
# Connection.command validates the session, but we use
# Connection.write_command
conn.validate_session(client, session)
last_run = False
while run:
@ -341,7 +341,7 @@ class _Bulk:
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
sock_info,
conn,
op_id,
listeners,
session,
@ -369,11 +369,11 @@ class _Bulk:
if retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info)
sock_info.send_cluster_time(cmd, session, client)
sock_info.add_server_api(cmd)
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
conn.send_cluster_time(cmd, session, client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
sock_info.apply_timeout(client, cmd)
conn.apply_timeout(client, cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible in one command.
@ -430,13 +430,13 @@ class _Bulk:
op_id = _randint()
def retryable_bulk(
session: Optional[ClientSession], sock_info: SocketInfo, retryable: bool
session: Optional[ClientSession], conn: Connection, retryable: bool
) -> None:
self._execute_command(
generator,
write_concern,
session,
sock_info,
conn,
op_id,
retryable,
full_result,
@ -450,7 +450,7 @@ class _Bulk:
_raise_bulk_write_error(full_result)
return full_result
def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[Any]) -> None:
def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name
client = self.collection.database.client
@ -466,7 +466,7 @@ class _Bulk:
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
sock_info,
conn,
op_id,
listeners,
None,
@ -482,7 +482,7 @@ class _Bulk:
("writeConcern", {"w": 0}),
]
)
sock_info.add_server_api(cmd)
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = bwc.execute_unack(cmd, ops, client)
@ -491,7 +491,7 @@ class _Bulk:
def execute_command_no_results(
self,
sock_info: SocketInfo,
conn: Connection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
@ -516,7 +516,7 @@ class _Bulk:
generator,
initial_write_concern,
None,
sock_info,
conn,
op_id,
False,
full_result,
@ -527,7 +527,7 @@ class _Bulk:
def execute_no_results(
self,
sock_info: SocketInfo,
conn: Connection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
@ -538,11 +538,11 @@ class _Bulk:
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
# Guard against unsupported unacknowledged writes.
unack = write_concern and not write_concern.acknowledged
if unack and self.uses_hint_delete and sock_info.max_wire_version < 9:
if unack and self.uses_hint_delete and conn.max_wire_version < 9:
raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
if unack and self.uses_hint_update and sock_info.max_wire_version < 8:
if unack and self.uses_hint_update and conn.max_wire_version < 8:
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
@ -553,8 +553,8 @@ class _Bulk:
)
if self.ordered:
return self.execute_command_no_results(sock_info, generator, write_concern)
return self.execute_op_msg_no_results(sock_info, generator)
return self.execute_command_no_results(conn, generator, write_concern)
return self.execute_op_msg_no_results(conn, generator)
def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any:
"""Execute operations."""
@ -573,8 +573,8 @@ class _Bulk:
client = self.collection.database.client
if not write_concern.acknowledged:
with client._socket_for_writes(session) as sock_info:
self.execute_no_results(sock_info, generator, write_concern)
with client._conn_for_writes(session) as connection:
self.execute_no_results(connection, generator, write_concern)
return None
else:
return self.execute_command(generator, write_concern, session)

View File

@ -78,7 +78,7 @@ if TYPE_CHECKING:
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.mongo_client import MongoClient
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
def _resumable(exc: PyMongoError) -> bool:
@ -213,7 +213,7 @@ class ChangeStream(Generic[_DocumentType]):
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> None:
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
"""Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents.
@ -228,7 +228,7 @@ class ChangeStream(Generic[_DocumentType]):
self._start_at_operation_time is None
and self._uses_resume_after is False
and self._uses_start_after is False
and sock_info.max_wire_version >= 7
and conn.max_wire_version >= 7
):
self._start_at_operation_time = result.get("operationTime")
# PYTHON-2181: informative error on missing operationTime.

View File

@ -160,7 +160,7 @@ from bson.int64 import Int64
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import _csot
from pymongo.cursor import _SocketManager
from pymongo.cursor import _ConnectionManager
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
@ -178,7 +178,7 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from types import TracebackType
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.server import Server
@ -400,7 +400,7 @@ class _Transaction:
self.state = _TxnState.NONE
self.sharded = False
self.pinned_address: Optional[Tuple[str, Optional[int]]] = None
self.sock_mgr: Optional[_SocketManager] = None
self.conn_mgr: Optional[_ConnectionManager] = None
self.recovery_token = None
self.attempt = 0
self.client = client
@ -412,23 +412,23 @@ class _Transaction:
return self.state == _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[SocketInfo]:
if self.active() and self.sock_mgr:
return self.sock_mgr.sock
def pinned_conn(self) -> Optional[Connection]:
if self.active() and self.conn_mgr:
return self.conn_mgr.conn
return None
def pin(self, server: Server, sock_info: SocketInfo) -> None:
def pin(self, server: Server, conn: Connection) -> None:
self.sharded = True
self.pinned_address = server.description.address
if server.description.server_type == SERVER_TYPE.LoadBalancer:
sock_info.pin_txn()
self.sock_mgr = _SocketManager(sock_info, False)
conn.pin_txn()
self.conn_mgr = _ConnectionManager(conn, False)
def unpin(self) -> None:
self.pinned_address = None
if self.sock_mgr:
self.sock_mgr.close()
self.sock_mgr = None
if self.conn_mgr:
self.conn_mgr.close()
self.conn_mgr = None
def reset(self) -> None:
self.unpin()
@ -438,11 +438,11 @@ class _Transaction:
self.attempt = 0
def __del__(self) -> None:
if self.sock_mgr:
if self.conn_mgr:
# Reuse the cursor closing machinery to return the socket to the
# pool soon.
self.client._close_cursor_soon(0, None, self.sock_mgr)
self.sock_mgr = None
self.client._close_cursor_soon(0, None, self.conn_mgr)
self.conn_mgr = None
def _reraise_with_unknown_commit(exc: Any) -> NoReturn:
@ -839,12 +839,12 @@ class ClientSession:
- `command_name`: Either "commitTransaction" or "abortTransaction".
"""
def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]:
return self._finish_transaction(sock_info, command_name)
def func(session: ClientSession, conn: Connection, retryable: bool) -> Dict[str, Any]:
return self._finish_transaction(conn, command_name)
return self._client._retry_internal(True, func, self, None)
def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]:
def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]:
self._transaction.attempt += 1
opts = self._transaction.opts
assert opts
@ -868,7 +868,7 @@ class ClientSession:
cmd["recoveryToken"] = self._transaction.recovery_token
return self._client.admin._command(
sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True
conn, cmd, session=self, write_concern=wc, parse_write_concern_error=True
)
def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
@ -954,13 +954,13 @@ class ClientSession:
return None
@property
def _pinned_connection(self) -> Optional[SocketInfo]:
def _pinned_connection(self) -> Optional[Connection]:
"""The connection this transaction was started on."""
return self._transaction.pinned_conn
def _pin(self, server: Server, sock_info: SocketInfo) -> None:
def _pin(self, server: Server, conn: Connection) -> None:
"""Pin this session to the given Server or to the given connection."""
self._transaction.pin(server, sock_info)
self._transaction.pin(server, conn)
def _unpin(self) -> None:
"""Unpin this session from any pinned Server."""
@ -985,12 +985,12 @@ class ClientSession:
command: MutableMapping[str, Any],
is_retryable: bool,
read_preference: ReadPreference,
sock_info: SocketInfo,
conn: Connection,
) -> None:
self._check_ended()
self._materialize()
if self.options.snapshot:
self._update_read_concern(command, sock_info)
self._update_read_concern(command, conn)
self._server_session.last_use = time.monotonic()
command["lsid"] = self._server_session.session_id
@ -1016,7 +1016,7 @@ class ClientSession:
rc = self._transaction.opts.read_concern.document
if rc:
command["readConcern"] = rc
self._update_read_concern(command, sock_info)
self._update_read_concern(command, conn)
command["txnNumber"] = self._server_session.transaction_id
command["autocommit"] = False
@ -1025,11 +1025,11 @@ class ClientSession:
self._check_ended()
self._server_session.inc_transaction_id()
def _update_read_concern(self, cmd: MutableMapping[str, Any], sock_info: SocketInfo) -> None:
def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None:
if self.options.causal_consistency and self.operation_time is not None:
cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time
if self.options.snapshot:
if sock_info.max_wire_version < 13:
if conn.max_wire_version < 13:
raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later")
rc = cmd.setdefault("readConcern", {})
rc["level"] = "snapshot"

View File

@ -126,7 +126,7 @@ if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collation import Collation
from pymongo.database import Database
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.server import Server
@ -262,17 +262,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
else:
self.__create(name, kwargs, collation, session)
def _socket_for_reads(
def _conn_for_reads(
self, session: ClientSession
) -> ContextManager[Tuple[SocketInfo, Union[PrimaryPreferred, Primary]]]:
return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]:
return self.__database.client._conn_for_reads(self._read_preference_for(session), session)
def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[SocketInfo]:
return self.__database.client._socket_for_writes(session)
def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]:
return self.__database.client._conn_for_writes(session)
def _command(
self,
sock_info: SocketInfo,
conn: Connection,
command: Mapping[str, Any],
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = None,
@ -288,7 +288,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""Internal command helper.
:Parameters:
- `sock_info` - A SocketInfo instance.
- `conn` - A Connection instance.
- `command` - The command itself, as a :class:`~bson.son.SON` instance.
- `read_preference` (optional) - The read preference to use.
- `codec_options` (optional) - An instance of
@ -313,7 +313,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
The result document.
"""
with self.__database.client._tmp_session(session) as s:
return sock_info.command(
return conn.command(
self.__database.name,
command,
read_preference or self._read_preference_for(session),
@ -348,16 +348,16 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
with self._socket_for_writes(session) as sock_info:
if qev2_required and sock_info.max_wire_version < 21:
with self._conn_for_writes(session) as conn:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
"Upgrade server to use Queryable Encryption. "
f"Got maxWireVersion {sock_info.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
f"Got maxWireVersion {conn.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
)
self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
@ -597,12 +597,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command["comment"] = comment
def _insert_command(
session: ClientSession, sock_info: SocketInfo, retryable_write: bool
session: ClientSession, conn: Connection, retryable_write: bool
) -> None:
if bypass_doc_val:
command["bypassDocumentValidation"] = True
result = sock_info.command(
result = conn.command(
self.__database.name,
command,
write_concern=write_concern,
@ -765,7 +765,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _update(
self,
sock_info: SocketInfo,
conn: Connection,
criteria: Mapping[str, Any],
document: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
@ -801,7 +801,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
else:
update_doc["arrayFilters"] = array_filters
if hint is not None:
if not acknowledged and sock_info.max_wire_version < 8:
if not acknowledged and conn.max_wire_version < 8:
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
@ -821,7 +821,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
# The command result has to be published for APM unmodified
# so we make a shallow copy here before adding updatedExisting.
result = sock_info.command(
result = conn.command(
self.__database.name,
command,
write_concern=write_concern,
@ -865,10 +865,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""Internal update / replace helper."""
def _update(
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool
session: Optional[ClientSession], conn: Connection, retryable_write: bool
) -> Optional[Mapping[str, Any]]:
return self._update(
sock_info,
conn,
criteria,
document,
upsert=upsert,
@ -1255,7 +1255,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _delete(
self,
sock_info: SocketInfo,
conn: Connection,
criteria: Mapping[str, Any],
multi: bool,
write_concern: Optional[WriteConcern] = None,
@ -1280,7 +1280,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
else:
delete_doc["collation"] = collation
if hint is not None:
if not acknowledged and sock_info.max_wire_version < 9:
if not acknowledged and conn.max_wire_version < 9:
raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
@ -1297,7 +1297,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command["comment"] = comment
# Delete command.
result = sock_info.command(
result = conn.command(
self.__database.name,
command,
write_concern=write_concern,
@ -1325,10 +1325,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""Internal delete helper."""
def _delete(
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool
session: Optional[ClientSession], conn: Connection, retryable_write: bool
) -> Mapping[str, Any]:
return self._delete(
sock_info,
conn,
criteria,
multi,
write_concern=write_concern,
@ -1738,7 +1738,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _count_cmd(
self,
session: ClientSession,
sock_info: SocketInfo,
conn: Connection,
read_preference: Optional[_ServerMode],
cmd: Mapping[str, Any],
collation: Optional[Collation],
@ -1747,7 +1747,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
# XXX: "ns missing" checks can be removed when we drop support for
# MongoDB 3.0, see SERVER-17051.
res = self._command(
sock_info,
conn,
cmd,
read_preference=read_preference,
allowable_errors=["ns missing"],
@ -1762,7 +1762,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _aggregate_one_result(
self,
sock_info: SocketInfo,
conn: Connection,
read_preference: Optional[_ServerMode],
cmd: Mapping[str, Any],
collation: Optional[_CollationIn],
@ -1770,7 +1770,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
) -> Optional[Mapping[str, Any]]:
"""Internal helper to run an aggregate that returns a single result."""
result = self._command(
sock_info,
conn,
cmd,
read_preference,
allowable_errors=[26], # Ignore NamespaceNotFound.
@ -1821,12 +1821,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: ClientSession,
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> int:
cmd: SON[str, Any] = SON([("count", self.__name)])
cmd.update(kwargs)
return self._count_cmd(session, sock_info, read_preference, cmd, collation=None)
return self._count_cmd(session, conn, read_preference, cmd, collation=None)
return self._retryable_non_cursor_read(_cmd, None)
@ -1910,10 +1910,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: ClientSession,
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> int:
result = self._aggregate_one_result(sock_info, read_preference, cmd, collation, session)
result = self._aggregate_one_result(conn, read_preference, cmd, collation, session)
if not result:
return 0
return result["n"]
@ -1922,7 +1922,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _retryable_non_cursor_read(
self,
func: Callable[[ClientSession, Server, SocketInfo, Optional[_ServerMode]], T],
func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T],
session: Optional[ClientSession],
) -> T:
"""Non-cursor read helper to handle implicit session creation."""
@ -1993,8 +1993,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
with self._socket_for_writes(session) as sock_info:
supports_quorum = sock_info.max_wire_version >= 9
with self._conn_for_writes(session) as conn:
supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]:
for index in indexes:
@ -2015,7 +2015,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
)
self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
@ -2236,9 +2236,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._socket_for_writes(session) as sock_info:
with self._conn_for_writes(session) as conn:
self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
@ -2285,7 +2285,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: ClientSession,
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]:
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
@ -2293,9 +2293,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd["comment"] = comment
try:
cursor = self._command(
sock_info, cmd, read_preference, codec_options, session=session
)["cursor"]
cursor = self._command(conn, cmd, read_preference, codec_options, session=session)[
"cursor"
]
except OperationFailure as exc:
# Ignore NamespaceNotFound errors to match the behavior
# of reading from *.system.indexes.
@ -2305,12 +2305,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd_cursor = CommandCursor(
coll,
cursor,
sock_info.address,
conn.address,
session=session,
explicit_session=explicit_session,
comment=cmd.get("comment"),
)
cmd_cursor._maybe_pin_connection(sock_info)
cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
with self.__database.client._tmp_session(session, False) as s:
@ -2479,9 +2479,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))])
cmd.update(kwargs)
with self._socket_for_writes(session) as sock_info:
with self._conn_for_writes(session) as conn:
resp = self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
@ -2514,9 +2514,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._socket_for_writes(session) as sock_info:
with self._conn_for_writes(session) as conn:
self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
@ -2551,9 +2551,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._socket_for_writes(session) as sock_info:
with self._conn_for_writes(session) as conn:
self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
@ -2980,9 +2980,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
with self._socket_for_writes(session) as sock_info:
with self._conn_for_writes(session) as conn:
with self.__database.client._tmp_session(session) as s:
return sock_info.command(
return conn.command(
"admin",
cmd,
write_concern=write_concern,
@ -3049,11 +3049,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: ClientSession,
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> List:
return self._command(
sock_info,
conn,
cmd,
read_preference=read_preference,
read_concern=self.read_concern,
@ -3112,7 +3112,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern = self._write_concern_for_cmd(cmd, session)
def _find_and_modify(
session: ClientSession, sock_info: SocketInfo, retryable_write: bool
session: ClientSession, conn: Connection, retryable_write: bool
) -> Any:
acknowledged = write_concern.acknowledged
if array_filters is not None:
@ -3122,17 +3122,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
)
cmd["arrayFilters"] = list(array_filters)
if hint is not None:
if sock_info.max_wire_version < 8:
if conn.max_wire_version < 8:
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on find and modify commands."
)
elif not acknowledged and sock_info.max_wire_version < 9:
elif not acknowledged and conn.max_wire_version < 9:
raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands."
)
cmd["hint"] = hint
out = self._command(
sock_info,
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
write_concern=write_concern,

View File

@ -29,7 +29,7 @@ from typing import (
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
from pymongo.response import PinnedResponse
@ -38,7 +38,7 @@ from pymongo.typings import _Address, _DocumentType
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
class CommandCursor(Generic[_DocumentType]):
@ -157,19 +157,19 @@ class CommandCursor(Generic[_DocumentType]):
"""
return self.__postbatchresumetoken
def _maybe_pin_connection(self, sock_info: SocketInfo) -> None:
def _maybe_pin_connection(self, conn: Connection) -> None:
client = self.__collection.database.client
if not client._should_pin_cursor(self.__session):
return
if not self.__sock_mgr:
sock_info.pin_cursor()
sock_mgr = _SocketManager(sock_info, False)
conn.pin_cursor()
conn_mgr = _ConnectionManager(conn, False)
# Ensure the connection gets returned when the entire result is
# returned in the first batch.
if self.__id == 0:
sock_mgr.close()
conn_mgr.close()
else:
self.__sock_mgr = sock_mgr
self.__sock_mgr = conn_mgr
def __send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
@ -197,7 +197,7 @@ class CommandCursor(Generic[_DocumentType]):
if isinstance(response, PinnedResponse):
if not self.__sock_mgr:
self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come)
self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]

View File

@ -64,7 +64,7 @@ if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.message import _OpMsg, _OpReply
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.read_preferences import _ServerMode
@ -139,11 +139,11 @@ class CursorType:
"""
class _SocketManager:
"""Used with exhaust cursors to ensure the socket is returned."""
class _ConnectionManager:
"""Used with exhaust cursors to ensure the connection is returned."""
def __init__(self, sock: SocketInfo, more_to_come: bool):
self.sock: Optional[SocketInfo] = sock
def __init__(self, conn: Connection, more_to_come: bool):
self.conn: Optional[Connection] = conn
self.more_to_come = more_to_come
self.lock = _create_lock()
@ -151,10 +151,10 @@ class _SocketManager:
self.more_to_come = more_to_come
def close(self) -> None:
"""Return this instance's socket to the connection pool."""
if self.sock:
self.sock.unpin()
self.sock = None
"""Return this instance's connection to the connection pool."""
if self.conn:
self.conn.unpin()
self.conn = None
_Sort = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]]
@ -1085,7 +1085,7 @@ class Cursor(Generic[_DocumentType]):
self.__address = response.address
if isinstance(response, PinnedResponse):
if not self.__sock_mgr:
self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come)
self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
cmd_name = operation.name
docs = response.docs

View File

@ -48,7 +48,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
if TYPE_CHECKING:
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.server import Server
@ -689,7 +689,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
@overload
def _command(
self,
sock_info: SocketInfo,
conn: Connection,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
@ -706,7 +706,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
@overload
def _command(
self,
sock_info: SocketInfo,
conn: Connection,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
@ -722,7 +722,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _command(
self,
sock_info: SocketInfo,
conn: Connection,
command: Union[str, MutableMapping[str, Any]],
value: int = 1,
check: bool = True,
@ -742,7 +742,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
command.update(kwargs)
with self.__client._tmp_session(session) as s:
return sock_info.command(
return conn.command(
self.__name,
command,
read_preference,
@ -889,12 +889,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
with self.__client._socket_for_reads(read_preference, session) as (
sock_info,
with self.__client._conn_for_reads(read_preference, session) as (
connection,
read_preference,
):
return self._command(
sock_info,
connection,
command,
value,
check,
@ -973,12 +973,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
read_preference = (
tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY
with self.__client._socket_for_reads(read_preference, tmp_session) as (
sock_info,
with self.__client._conn_for_reads(read_preference, tmp_session) as (
conn,
read_preference,
):
response = self._command(
sock_info,
conn,
command,
value,
True,
@ -993,13 +993,13 @@ class Database(common.BaseObject, Generic[_DocumentType]):
cmd_cursor = CommandCursor(
coll,
response["cursor"],
sock_info.address,
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
explicit_session=session is not None,
comment=comment,
)
cmd_cursor._maybe_pin_connection(sock_info)
cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
else:
raise InvalidOperation("Command does not return a cursor.")
@ -1015,11 +1015,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: _ServerMode,
) -> Dict[str, Any]:
return self._command(
sock_info,
conn,
command,
read_preference=read_preference,
session=session,
@ -1029,7 +1029,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _list_collections(
self,
sock_info: SocketInfo,
conn: Connection,
session: Optional[ClientSession],
read_preference: _ServerMode,
**kwargs: Any,
@ -1039,18 +1039,18 @@ class Database(common.BaseObject, Generic[_DocumentType]):
cmd = SON([("listCollections", 1), ("cursor", {})])
cmd.update(kwargs)
with self.__client._tmp_session(session, close=False) as tmp_session:
cursor = self._command(
sock_info, cmd, read_preference=read_preference, session=tmp_session
)["cursor"]
cursor = self._command(conn, cmd, read_preference=read_preference, session=tmp_session)[
"cursor"
]
cmd_cursor = CommandCursor(
coll,
cursor,
sock_info.address,
conn.address,
session=tmp_session,
explicit_session=session is not None,
comment=cmd.get("comment"),
)
cmd_cursor._maybe_pin_connection(sock_info)
cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
def list_collections(
@ -1090,12 +1090,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _cmd(
session: Optional[ClientSession],
server: Server,
sock_info: SocketInfo,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]:
return self._list_collections(
sock_info, session, read_preference=read_preference, **kwargs
)
return self._list_collections(conn, session, read_preference=read_preference, **kwargs)
return self.__client._retryable_read(_cmd, read_pref, session)
@ -1154,9 +1152,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
with self.__client._socket_for_writes(session) as sock_info:
with self.__client._conn_for_writes(session) as connection:
return self._command(
sock_info,
connection,
command,
allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session),

View File

@ -307,7 +307,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
try:
return func(*args, **kwargs)
@ -315,19 +315,19 @@ def _handle_reauth(func: F) -> F:
if no_reauth:
raise
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
# Look for an argument that either is a SocketInfo
# or has a socket_info attribute, so we can trigger
# Look for an argument that either is a Connection
# or has a connection attribute, so we can trigger
# a reauth.
sock_info = None
conn = None
for arg in args:
if isinstance(arg, SocketInfo):
sock_info = arg
if isinstance(arg, Connection):
conn = arg
break
if hasattr(arg, "sock_info"):
sock_info = arg.sock_info
if hasattr(arg, "connection"):
conn = arg.conn
break
if sock_info:
sock_info.authenticate(reauthenticate=True)
if conn:
conn.authenticate(reauthenticate=True)
else:
raise
return func(*args, **kwargs)

View File

@ -227,14 +227,14 @@ def _gen_find_command(
return cmd
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, sock_info):
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, conn):
"""Generate a getMore command document."""
cmd = SON([("getMore", cursor_id), ("collection", coll)])
if batch_size:
cmd["batchSize"] = batch_size
if max_await_time_ms is not None:
cmd["maxTimeMS"] = max_await_time_ms
if comment is not None and sock_info.max_wire_version >= 9:
if comment is not None and conn.max_wire_version >= 9:
cmd["comment"] = comment
return cmd
@ -264,7 +264,7 @@ class _Query:
)
# For compatibility with the _GetMore class.
sock_mgr = None
conn_mgr = None
cursor_id = None
def __init__(
@ -311,24 +311,23 @@ class _Query:
def namespace(self):
return f"{self.db}.{self.coll}"
def use_command(self, sock_info):
def use_command(self, conn):
use_find_cmd = False
if not self.exhaust:
use_find_cmd = True
elif sock_info.max_wire_version >= 8:
elif conn.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+
use_find_cmd = True
elif not self.read_concern.ok_for_legacy:
raise ConfigurationError(
"read concern level of %s is not valid "
"with a max wire version of %d."
% (self.read_concern.level, sock_info.max_wire_version)
"with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version)
)
sock_info.validate_session(self.client, self.session)
conn.validate_session(self.client, self.session)
return use_find_cmd
def as_command(self, sock_info, apply_timeout=False):
def as_command(self, conn, apply_timeout=False):
"""Return a find command document for this query."""
# We use the command twice: on the wire and for command monitoring.
# Generate it once, for speed and to avoid repeating side-effects.
@ -353,24 +352,24 @@ class _Query:
self.name = "explain"
cmd = SON([("explain", cmd)])
session = self.session
sock_info.add_server_api(cmd)
conn.add_server_api(cmd)
if session:
session._apply_to(cmd, False, self.read_preference, sock_info)
session._apply_to(cmd, False, self.read_preference, conn)
# Explain does not support readConcern.
if not explain and not session.in_transaction:
session._update_read_concern(cmd, sock_info)
sock_info.send_cluster_time(cmd, session, self.client)
session._update_read_concern(cmd, conn)
conn.send_cluster_time(cmd, session, self.client)
# Support auto encryption
client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
# Support CSOT
if apply_timeout:
sock_info.apply_timeout(client, cmd)
conn.apply_timeout(client, cmd)
self._as_command = cmd, self.db
return self._as_command
def get_message(self, read_preference, sock_info, use_cmd=False):
def get_message(self, read_preference, conn, use_cmd=False):
"""Get a query message, possibly setting the secondaryOk bit."""
# Use the read_preference decided by _socket_from_server.
self.read_preference = read_preference
@ -384,14 +383,14 @@ class _Query:
spec = self.spec
if use_cmd:
spec = self.as_command(sock_info, apply_timeout=True)[0]
spec = self.as_command(conn, apply_timeout=True)[0]
request_id, msg, size, _ = _op_msg(
0,
spec,
self.db,
read_preference,
self.codec_options,
ctx=sock_info.compression_context,
ctx=conn.compression_context,
)
return request_id, msg, size
@ -405,7 +404,7 @@ class _Query:
else:
ntoreturn = self.limit
if sock_info.is_mongos:
if conn.is_mongos:
spec = _maybe_add_read_preference(spec, read_preference)
return _query(
@ -416,7 +415,7 @@ class _Query:
spec,
None if use_cmd else self.fields,
self.codec_options,
ctx=sock_info.compression_context,
ctx=conn.compression_context,
)
@ -433,7 +432,7 @@ class _GetMore:
"read_preference",
"session",
"client",
"sock_mgr",
"conn_mgr",
"_as_command",
"exhaust",
"comment",
@ -452,7 +451,7 @@ class _GetMore:
session,
client,
max_await_time_ms,
sock_mgr,
conn_mgr,
exhaust,
comment,
):
@ -465,7 +464,7 @@ class _GetMore:
self.session = session
self.client = client
self.max_await_time_ms = max_await_time_ms
self.sock_mgr = sock_mgr
self.conn_mgr = conn_mgr
self._as_command = None
self.exhaust = exhaust
self.comment = comment
@ -476,18 +475,18 @@ class _GetMore:
def namespace(self):
return f"{self.db}.{self.coll}"
def use_command(self, sock_info):
def use_command(self, conn):
use_cmd = False
if not self.exhaust:
use_cmd = True
elif sock_info.max_wire_version >= 8:
elif conn.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+
use_cmd = True
sock_info.validate_session(self.client, self.session)
conn.validate_session(self.client, self.session)
return use_cmd
def as_command(self, sock_info, apply_timeout=False):
def as_command(self, conn, apply_timeout=False):
"""Return a getMore command document for this query."""
# See _Query.as_command for an explanation of this caching.
if self._as_command is not None:
@ -499,35 +498,35 @@ class _GetMore:
self.ntoreturn,
self.max_await_time_ms,
self.comment,
sock_info,
conn,
)
if self.session:
self.session._apply_to(cmd, False, self.read_preference, sock_info)
sock_info.add_server_api(cmd)
sock_info.send_cluster_time(cmd, self.session, self.client)
self.session._apply_to(cmd, False, self.read_preference, conn)
conn.add_server_api(cmd)
conn.send_cluster_time(cmd, self.session, self.client)
# Support auto encryption
client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
# Support CSOT
if apply_timeout:
sock_info.apply_timeout(client, cmd=None)
conn.apply_timeout(client, cmd=None)
self._as_command = cmd, self.db
return self._as_command
def get_message(self, dummy0, sock_info, use_cmd=False):
def get_message(self, dummy0, conn, use_cmd=False):
"""Get a getmore message."""
ns = self.namespace()
ctx = sock_info.compression_context
ctx = conn.compression_context
if use_cmd:
spec = self.as_command(sock_info, apply_timeout=True)[0]
if self.sock_mgr:
spec = self.as_command(conn, apply_timeout=True)[0]
if self.conn_mgr:
flags = _OpMsg.EXHAUST_ALLOWED
else:
flags = 0
request_id, msg, size, _ = _op_msg(
flags, spec, self.db, None, self.codec_options, ctx=sock_info.compression_context
flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
)
return request_id, msg, size
@ -535,10 +534,10 @@ class _GetMore:
class _RawBatchQuery(_Query):
def use_command(self, sock_info):
def use_command(self, conn):
# Compatibility checks.
super().use_command(sock_info)
if sock_info.max_wire_version >= 8:
super().use_command(conn)
if conn.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True
elif not self.exhaust:
@ -547,10 +546,10 @@ class _RawBatchQuery(_Query):
class _RawBatchGetMore(_GetMore):
def use_command(self, sock_info):
def use_command(self, conn):
# Compatibility checks.
super().use_command(sock_info)
if sock_info.max_wire_version >= 8:
super().use_command(conn)
if conn.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True
elif not self.exhaust:
@ -794,11 +793,11 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
class _BulkWriteContext:
"""A wrapper around SocketInfo for use with write splitting functions."""
"""A wrapper around Connection for use with write splitting functions."""
__slots__ = (
"db_name",
"sock_info",
"conn",
"op_id",
"name",
"field",
@ -812,10 +811,10 @@ class _BulkWriteContext:
)
def __init__(
self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec
self, database_name, cmd_name, conn, operation_id, listeners, session, op_type, codec
):
self.db_name = database_name
self.sock_info = sock_info
self.conn = conn
self.op_id = operation_id
self.listeners = listeners
self.publish = listeners.enabled_for_commands
@ -823,7 +822,7 @@ class _BulkWriteContext:
self.field = _FIELD_MAP[self.name]
self.start_time = datetime.datetime.now() if self.publish else None
self.session = session
self.compress = True if sock_info.compression_context else False
self.compress = True if conn.compression_context else False
self.op_type = op_type
self.codec = codec
@ -855,20 +854,20 @@ class _BulkWriteContext:
@property
def max_bson_size(self):
"""A proxy for SockInfo.max_bson_size."""
return self.sock_info.max_bson_size
return self.conn.max_bson_size
@property
def max_message_size(self):
"""A proxy for SockInfo.max_message_size."""
if self.compress:
# Subtract 16 bytes for the message header.
return self.sock_info.max_message_size - 16
return self.sock_info.max_message_size
return self.conn.max_message_size - 16
return self.conn.max_message_size
@property
def max_write_batch_size(self):
"""A proxy for SockInfo.max_write_batch_size."""
return self.sock_info.max_write_batch_size
return self.conn.max_write_batch_size
@property
def max_split_size(self):
@ -876,14 +875,14 @@ class _BulkWriteContext:
return self.max_bson_size
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
"""A proxy for SocketInfo.unack_write that handles event publishing."""
"""A proxy for Connection.unack_write that handles event publishing."""
if self.publish:
assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time
cmd = self._start(cmd, request_id, docs)
start = datetime.datetime.now()
try:
result = self.sock_info.unack_write(msg, max_doc_size)
result = self.conn.unack_write(msg, max_doc_size)
if self.publish:
duration = (datetime.datetime.now() - start) + duration
if result is not None:
@ -910,14 +909,14 @@ class _BulkWriteContext:
@_handle_reauth
def write_command(self, cmd, request_id, msg, docs):
"""A proxy for SocketInfo.write_command that handles event publishing."""
"""A proxy for Connection.write_command that handles event publishing."""
if self.publish:
assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time
self._start(cmd, request_id, docs)
start = datetime.datetime.now()
try:
reply = self.sock_info.write_command(request_id, msg, self.codec)
reply = self.conn.write_command(request_id, msg, self.codec)
if self.publish:
duration = (datetime.datetime.now() - start) + duration
self._succeed(request_id, reply, duration)
@ -941,9 +940,9 @@ class _BulkWriteContext:
cmd,
self.db_name,
request_id,
self.sock_info.address,
self.conn.address,
self.op_id,
self.sock_info.service_id,
self.conn.service_id,
)
return cmd
@ -954,9 +953,9 @@ class _BulkWriteContext:
reply,
self.name,
request_id,
self.sock_info.address,
self.conn.address,
self.op_id,
self.sock_info.service_id,
self.conn.service_id,
)
def _fail(self, request_id, failure, duration):
@ -966,9 +965,9 @@ class _BulkWriteContext:
failure,
self.name,
request_id,
self.sock_info.address,
self.conn.address,
self.op_id,
self.sock_info.service_id,
self.conn.service_id,
)
@ -997,14 +996,14 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
def execute(self, cmd, docs, client):
batched_cmd, to_send = self._batch_command(cmd, docs)
result = self.sock_info.command(
result = self.conn.command(
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
)
return result, to_send
def execute_unack(self, cmd, docs, client):
batched_cmd, to_send = self._batch_command(cmd, docs)
self.sock_info.command(
self.conn.command(
self.db_name,
batched_cmd,
write_concern=WriteConcern(w=0),
@ -1124,7 +1123,7 @@ def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx):
"""
data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
request_id, msg = _compress(2013, data, ctx.sock_info.compression_context)
request_id, msg = _compress(2013, data, ctx.conn.compression_context)
return request_id, msg, to_send
@ -1162,7 +1161,7 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
ack = bool(command["writeConcern"].get("w", 1))
else:
ack = True
if ctx.sock_info.compression_context:
if ctx.conn.compression_context:
return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx)
return _batched_op_msg(operation, command, docs, ack, opts, ctx)

View File

@ -1160,18 +1160,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _end_sessions(self, session_ids):
"""Send endSessions command(s) with the given session ids."""
try:
# Use SocketInfo.command directly to avoid implicitly creating
# Use Connection.command directly to avoid implicitly creating
# another session.
with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as (
sock_info,
with self._conn_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as (
conn,
read_pref,
):
if not sock_info.supports_sessions:
if not conn.supports_sessions:
return
for i in range(0, len(session_ids), common._MAX_END_SESSIONS):
spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])])
sock_info.command("admin", spec, read_preference=read_pref, client=self)
conn.command("admin", spec, read_preference=read_pref, client=self)
except PyMongoError:
# Drivers MUST ignore any errors returned by the endSessions
# command.
@ -1216,7 +1216,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return self._topology
@contextlib.contextmanager
def _get_socket(self, server, session):
def _checkout(self, server, session):
in_txn = session and session.in_transaction
with _MongoClientErrorHandler(self, server, session) as err_handler:
# Reuse the pinned connection, if it exists.
@ -1224,23 +1224,23 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
err_handler.contribute_socket(session._pinned_connection)
yield session._pinned_connection
return
with server.get_socket(handler=err_handler) as sock_info:
with server.checkout(handler=err_handler) as conn:
# Pin this session to the selected server or connection.
if in_txn and server.description.server_type in (
SERVER_TYPE.Mongos,
SERVER_TYPE.LoadBalancer,
):
session._pin(server, sock_info)
err_handler.contribute_socket(sock_info)
session._pin(server, conn)
err_handler.contribute_socket(conn)
if (
self._encrypter
and not self._encrypter._bypass_auto_encryption
and sock_info.max_wire_version < 8
and conn.max_wire_version < 8
):
raise ConfigurationError(
"Auto-encryption requires a minimum MongoDB version of 4.2"
)
yield sock_info
yield conn
def _select_server(self, server_selector, session, address=None):
"""Select a server to run an operation on this client.
@ -1273,15 +1273,15 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
session._unpin()
raise
def _socket_for_writes(self, session):
def _conn_for_writes(self, session):
server = self._select_server(writable_server_selector, session)
return self._get_socket(server, session)
return self._checkout(server, session)
@contextlib.contextmanager
def _socket_from_server(self, read_preference, server, session):
def _conn_from_server(self, read_preference, server, session):
assert read_preference is not None, "read_preference must not be None"
# Get a socket for a server matching the read preference, and yield
# sock_info with the effective read preference. The Server Selection
# Get a connection for a server matching the read preference, and yield
# conn with the effective read preference. The Server Selection
# Spec says not to send any $readPreference to standalones and to
# always send primaryPreferred when directly connected to a repl set
# member.
@ -1289,22 +1289,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
topology = self._get_topology()
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
with self._get_socket(server, session) as sock_info:
with self._checkout(server, session) as conn:
if single:
if sock_info.is_repl and not (session and session.in_transaction):
if conn.is_repl and not (session and session.in_transaction):
# Use primary preferred to ensure any repl set member
# can handle the request.
read_preference = ReadPreference.PRIMARY_PREFERRED
elif sock_info.is_standalone:
elif conn.is_standalone:
# Don't send read preference to standalones.
read_preference = ReadPreference.PRIMARY
yield sock_info, read_preference
yield conn, read_preference
def _socket_for_reads(self, read_preference, session):
def _conn_for_reads(self, read_preference, session):
assert read_preference is not None, "read_preference must not be None"
_ = self._get_topology()
server = self._select_server(read_preference, session)
return self._socket_from_server(read_preference, server, session)
return self._conn_from_server(read_preference, server, session)
def _should_pin_cursor(self, session):
return self.__options.load_balanced and not (session and session.in_transaction)
@ -1319,22 +1319,26 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
- `address` (optional): Optional address when sending a message
to a specific server, used for getMore.
"""
if operation.sock_mgr:
if operation.conn_mgr:
server = self._select_server(
operation.read_preference, operation.session, address=address
)
with operation.sock_mgr.lock:
with operation.conn_mgr.lock:
with _MongoClientErrorHandler(self, server, operation.session) as err_handler:
err_handler.contribute_socket(operation.sock_mgr.sock)
err_handler.contribute_socket(operation.conn_mgr.conn)
return server.run_operation(
operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res
operation.conn_mgr.conn,
operation,
True,
self._event_listeners,
unpack_res,
)
def _cmd(session, server, sock_info, read_preference):
def _cmd(session, server, conn, read_preference):
operation.reset() # Reset op in case of retry.
return server.run_operation(
sock_info, operation, read_preference, self._event_listeners, unpack_res
conn, operation, read_preference, self._event_listeners, unpack_res
)
return self._retryable_read(
@ -1388,8 +1392,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
supports_session = (
session is not None and server.description.retryable_writes_supported
)
with self._get_socket(server, session) as sock_info:
max_wire_version = sock_info.max_wire_version
with self._checkout(server, session) as conn:
max_wire_version = conn.max_wire_version
if retryable and not supports_session:
if is_retrying():
# A retry is not possible because this server does
@ -1397,7 +1401,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
assert last_error is not None
raise last_error
retryable = False
return func(session, sock_info, retryable)
return func(session, conn, retryable)
except ServerSelectionTimeoutError:
if is_retrying():
# The application may think the write was never attempted
@ -1455,13 +1459,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
raise last_error
try:
server = self._select_server(read_pref, session, address=address)
with self._socket_from_server(read_pref, server, session) as (sock_info, read_pref):
with self._conn_from_server(read_pref, server, session) as (
conn,
read_pref,
):
if retrying and not retryable:
# A retry is not possible because this server does
# not support retryable reads, raise the last error.
assert last_error is not None
raise last_error
return func(session, server, sock_info, read_pref)
return func(session, server, conn, read_pref)
except ServerSelectionTimeoutError:
if retrying:
# The application may think the write was never attempted
@ -1566,7 +1573,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return database.Database(self, name)
def _cleanup_cursor(
self, locks_allowed, cursor_id, address, sock_mgr, session, explicit_session
self, locks_allowed, cursor_id, address, conn_mgr, session, explicit_session
):
"""Cleanup a cursor from cursor.close() or __del__.
@ -1578,33 +1585,33 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
- `locks_allowed`: True if we are allowed to acquire locks.
- `cursor_id`: The cursor id which may be 0.
- `address`: The _CursorAddress.
- `sock_mgr`: The _SocketManager for the pinned connection or None.
- `conn_mgr`: The _ConnectionManager for the pinned connection or None.
- `session`: The cursor's session.
- `explicit_session`: True if the session was passed explicitly.
"""
if locks_allowed:
if cursor_id:
if sock_mgr and sock_mgr.more_to_come:
if conn_mgr and conn_mgr.more_to_come:
# If this is an exhaust cursor and we haven't completely
# exhausted the result set we *must* close the socket
# to stop the server from sending more data.
sock_mgr.sock.close_socket(ConnectionClosedReason.ERROR)
conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
else:
self._close_cursor_now(cursor_id, address, session=session, sock_mgr=sock_mgr)
if sock_mgr:
sock_mgr.close()
self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
if conn_mgr:
conn_mgr.close()
else:
# The cursor will be closed later in a different session.
if cursor_id or sock_mgr:
self._close_cursor_soon(cursor_id, address, sock_mgr)
if cursor_id or conn_mgr:
self._close_cursor_soon(cursor_id, address, conn_mgr)
if session and not explicit_session:
session._end_session(lock=locks_allowed)
def _close_cursor_soon(self, cursor_id, address, sock_mgr=None):
def _close_cursor_soon(self, cursor_id, address, conn_mgr=None):
"""Request that a cursor and/or connection be cleaned up soon."""
self.__kill_cursors_queue.append((address, cursor_id, sock_mgr))
self.__kill_cursors_queue.append((address, cursor_id, conn_mgr))
def _close_cursor_now(self, cursor_id, address=None, session=None, sock_mgr=None):
def _close_cursor_now(self, cursor_id, address=None, session=None, conn_mgr=None):
"""Send a kill cursors message with the given id.
The cursor is closed synchronously on the current thread.
@ -1613,10 +1620,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
raise TypeError("cursor_id must be an instance of int")
try:
if sock_mgr:
with sock_mgr.lock:
if conn_mgr:
with conn_mgr.lock:
# Cursor is pinned to LB outside of a transaction.
self._kill_cursor_impl([cursor_id], address, session, sock_mgr.sock)
self._kill_cursor_impl([cursor_id], address, session, conn_mgr.conn)
else:
self._kill_cursors([cursor_id], address, self._get_topology(), session)
except PyMongoError:
@ -1633,14 +1640,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# Application called close_cursor() with no address.
server = topology.select_server(writable_server_selector)
with self._get_socket(server, session) as sock_info:
self._kill_cursor_impl(cursor_ids, address, session, sock_info)
with self._checkout(server, session) as conn:
self._kill_cursor_impl(cursor_ids, address, session, conn)
def _kill_cursor_impl(self, cursor_ids, address, session, sock_info):
def _kill_cursor_impl(self, cursor_ids, address, session, conn):
namespace = address.namespace
db, coll = namespace.split(".", 1)
spec = SON([("killCursors", coll), ("cursors", cursor_ids)])
sock_info.command(db, spec, session=session, client=self)
conn.command(db, spec, session=session, client=self)
def _process_kill_cursors(self):
"""Process any pending kill cursors requests."""
@ -1650,18 +1657,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# Other threads or the GC may append to the queue concurrently.
while True:
try:
address, cursor_id, sock_mgr = self.__kill_cursors_queue.pop()
address, cursor_id, conn_mgr = self.__kill_cursors_queue.pop()
except IndexError:
break
if sock_mgr:
pinned_cursors.append((address, cursor_id, sock_mgr))
if conn_mgr:
pinned_cursors.append((address, cursor_id, conn_mgr))
else:
address_to_cursor_ids[address].append(cursor_id)
for address, cursor_id, sock_mgr in pinned_cursors:
for address, cursor_id, conn_mgr in pinned_cursors:
try:
self._cleanup_cursor(True, cursor_id, address, sock_mgr, None, False)
self._cleanup_cursor(True, cursor_id, address, conn_mgr, None, False)
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -1925,9 +1932,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if not isinstance(name, str):
raise TypeError("name_or_database must be an instance of str or a Database")
with self._socket_for_writes(session) as sock_info:
with self._conn_for_writes(session) as conn:
self[name]._command(
sock_info,
conn,
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
@ -2149,11 +2156,11 @@ class _MongoClientErrorHandler:
self.service_id = None
self.handled = False
def contribute_socket(self, sock_info, completed_handshake=True):
def contribute_socket(self, conn, completed_handshake=True):
"""Provide socket information to the error handler."""
self.max_wire_version = sock_info.max_wire_version
self.sock_generation = sock_info.generation
self.service_id = sock_info.service_id
self.max_wire_version = conn.max_wire_version
self.sock_generation = conn.generation
self.service_id = conn.service_id
self.completed_handshake = completed_handshake
def handle(self, exc_type, exc_val):

View File

@ -245,9 +245,9 @@ class Monitor(MonitorBase):
if self._cancel_context and self._cancel_context.cancelled:
self._reset_connection()
with self._pool.get_socket() as sock_info:
self._cancel_context = sock_info.cancel_context
response, round_trip_time = self._check_with_socket(sock_info)
with self._pool.checkout() as conn:
self._cancel_context = conn.cancel_context
response, round_trip_time = self._check_with_socket(conn)
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
@ -393,11 +393,11 @@ class _RttMonitor(MonitorBase):
def _ping(self):
"""Run a "hello" command and return the RTT."""
with self._pool.get_socket() as sock_info:
with self._pool.checkout() as conn:
if self._executor._stopped:
raise Exception("_RttMonitor closed")
start = time.monotonic()
sock_info.hello()
conn.hello()
return time.monotonic() - start

View File

@ -135,15 +135,15 @@ Connection monitoring and pooling events are also available. For example::
logging.info("[pool {0.address}] pool closed".format(event))
def connection_created(self, event):
logging.info("[pool {0.address}][conn #{0.connection_id}] "
logging.info("[pool {0.address}][connection #{0.connection_id}] "
"connection created".format(event))
def connection_ready(self, event):
logging.info("[pool {0.address}][conn #{0.connection_id}] "
logging.info("[pool {0.address}][connection #{0.connection_id}] "
"connection setup succeeded".format(event))
def connection_closed(self, event):
logging.info("[pool {0.address}][conn #{0.connection_id}] "
logging.info("[pool {0.address}][connection #{0.connection_id}] "
"connection closed, reason: "
"{0.reason}".format(event))
@ -156,11 +156,11 @@ Connection monitoring and pooling events are also available. For example::
"failed, reason: {0.reason}".format(event))
def connection_checked_out(self, event):
logging.info("[pool {0.address}][conn #{0.connection_id}] "
logging.info("[pool {0.address}][connection #{0.connection_id}] "
"connection checked out of pool".format(event))
def connection_checked_in(self, event):
logging.info("[pool {0.address}][conn #{0.connection_id}] "
logging.info("[pool {0.address}][connection #{0.connection_id}] "
"connection checked into pool".format(event))
@ -268,7 +268,7 @@ class ConnectionPoolListener(_EventListener):
def pool_created(self, event: "PoolCreatedEvent") -> None:
"""Abstract method to handle a :class:`PoolCreatedEvent`.
Emitted when a Connection Pool is created.
Emitted when a connection Pool is created.
:Parameters:
- `event`: An instance of :class:`PoolCreatedEvent`.
@ -278,7 +278,7 @@ class ConnectionPoolListener(_EventListener):
def pool_ready(self, event: "PoolReadyEvent") -> None:
"""Abstract method to handle a :class:`PoolReadyEvent`.
Emitted when a Connection Pool is marked ready.
Emitted when a connection Pool is marked ready.
:Parameters:
- `event`: An instance of :class:`PoolReadyEvent`.
@ -290,7 +290,7 @@ class ConnectionPoolListener(_EventListener):
def pool_cleared(self, event: "PoolClearedEvent") -> None:
"""Abstract method to handle a `PoolClearedEvent`.
Emitted when a Connection Pool is cleared.
Emitted when a connection Pool is cleared.
:Parameters:
- `event`: An instance of :class:`PoolClearedEvent`.
@ -300,7 +300,7 @@ class ConnectionPoolListener(_EventListener):
def pool_closed(self, event: "PoolClosedEvent") -> None:
"""Abstract method to handle a `PoolClosedEvent`.
Emitted when a Connection Pool is closed.
Emitted when a connection Pool is closed.
:Parameters:
- `event`: An instance of :class:`PoolClosedEvent`.
@ -310,7 +310,7 @@ class ConnectionPoolListener(_EventListener):
def connection_created(self, event: "ConnectionCreatedEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCreatedEvent`.
Emitted when a Connection Pool creates a Connection object.
Emitted when a connection Pool creates a Connection object.
:Parameters:
- `event`: An instance of :class:`ConnectionCreatedEvent`.
@ -320,7 +320,7 @@ class ConnectionPoolListener(_EventListener):
def connection_ready(self, event: "ConnectionReadyEvent") -> None:
"""Abstract method to handle a :class:`ConnectionReadyEvent`.
Emitted when a Connection has finished its setup, and is now ready to
Emitted when a connection has finished its setup, and is now ready to
use.
:Parameters:
@ -331,7 +331,7 @@ class ConnectionPoolListener(_EventListener):
def connection_closed(self, event: "ConnectionClosedEvent") -> None:
"""Abstract method to handle a :class:`ConnectionClosedEvent`.
Emitted when a Connection Pool closes a Connection.
Emitted when a connection Pool closes a connection.
:Parameters:
- `event`: An instance of :class:`ConnectionClosedEvent`.
@ -361,7 +361,7 @@ class ConnectionPoolListener(_EventListener):
def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCheckedOutEvent`.
Emitted when the driver successfully checks out a Connection.
Emitted when the driver successfully checks out a connection.
:Parameters:
- `event`: An instance of :class:`ConnectionCheckedOutEvent`.
@ -371,7 +371,7 @@ class ConnectionPoolListener(_EventListener):
def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCheckedInEvent`.
Emitted when the driver checks in a Connection back to the Connection
Emitted when the driver checks in a connection back to the connection
Pool.
:Parameters:
@ -948,7 +948,7 @@ class _ConnectionIdEvent(_ConnectionEvent):
@property
def connection_id(self) -> int:
"""The ID of the Connection."""
"""The ID of the connection."""
return self.__connection_id
def __repr__(self):
@ -1066,7 +1066,7 @@ class ConnectionCheckOutFailedEvent(_ConnectionEvent):
class ConnectionCheckedOutEvent(_ConnectionIdEvent):
"""Published when the driver successfully checks out a Connection.
"""Published when the driver successfully checks out a connection.
:Parameters:
- `address`: The address (host, port) pair of the server this

View File

@ -52,7 +52,7 @@ if TYPE_CHECKING:
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.mongo_client import MongoClient
from pymongo.monitoring import _EventListeners
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address
@ -62,7 +62,7 @@ _UNPACK_HEADER = struct.Struct("<iiii").unpack
def command(
sock_info: SocketInfo,
conn: Connection,
dbname: str,
spec: MutableMapping[str, Any],
is_mongos: bool,
@ -88,7 +88,7 @@ def command(
"""Execute a command over the socket, or raise socket.error.
:Parameters:
- `sock`: a raw socket instance
- `conn`: a Connection instance
- `dbname`: name of the database on which to run the command
- `spec`: a command document as an ordered dict type, eg SON.
- `is_mongos`: are we connected to a mongos?
@ -98,7 +98,7 @@ def command(
- `client`: optional MongoClient instance for updating $clusterTime.
- `check`: raise OperationFailure if there are errors
- `allowable_errors`: errors to ignore if `check` is True
- `address`: the (host, port) of `sock`
- `address`: the (host, port) of `conn`
- `listeners`: An instance of :class:`~pymongo.monitoring.EventListeners`
- `max_bson_size`: The maximum encoded bson size for this server
- `read_concern`: The read concern for this command.
@ -125,7 +125,7 @@ def command(
if read_concern.level:
spec["readConcern"] = read_concern.document
if session:
session._update_read_concern(spec, sock_info)
session._update_read_concern(spec, conn)
if collation is not None:
spec["collation"] = collation
@ -142,7 +142,7 @@ def command(
# Support CSOT
if client:
sock_info.apply_timeout(client, spec)
conn.apply_timeout(client, spec)
_csot.apply_write_concern(spec, write_concern)
if use_op_msg:
@ -167,19 +167,19 @@ def command(
encoding_duration = datetime.datetime.now() - start
assert listeners is not None
listeners.publish_command_start(
orig, dbname, request_id, address, service_id=sock_info.service_id
orig, dbname, request_id, address, service_id=conn.service_id
)
start = datetime.datetime.now()
try:
sock_info.sock.sendall(msg)
conn.conn.sendall(msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc = {"ok": 1}
else:
reply = receive_message(sock_info, request_id)
sock_info.more_to_come = reply.more_to_come
reply = receive_message(conn, request_id)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
)
@ -190,7 +190,7 @@ def command(
if check:
helpers._check_command_response(
response_doc,
sock_info.max_wire_version,
conn.max_wire_version,
allowable_errors,
parse_write_concern_error=parse_write_concern_error,
)
@ -203,7 +203,7 @@ def command(
failure = message._convert_exception(exc)
assert listeners is not None
listeners.publish_command_failure(
duration, failure, name, request_id, address, service_id=sock_info.service_id
duration, failure, name, request_id, address, service_id=conn.service_id
)
raise
if publish:
@ -215,7 +215,7 @@ def command(
name,
request_id,
address,
service_id=sock_info.service_id,
service_id=conn.service_id,
speculative_hello=speculative_hello,
)
@ -230,21 +230,19 @@ _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
def receive_message(
sock_info: SocketInfo, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
conn: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = sock_info.sock.gettimeout()
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
_receive_data_on_socket(sock_info, 16, deadline)
)
length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
@ -260,11 +258,11 @@ def receive_message(
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(sock_info, 9, deadline)
_receive_data_on_socket(conn, 9, deadline)
)
data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id)
data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
else:
data = _receive_data_on_socket(sock_info, length - 16, deadline)
data = _receive_data_on_socket(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
@ -276,12 +274,12 @@ def receive_message(
_POLL_TIMEOUT = 0.5
def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
context = sock_info.cancel_context
context = conn.cancel_context
# Only Monitor connections can be cancelled.
if context:
sock = sock_info.sock
sock = conn.conn
timed_out = False
while True:
# SSLSocket can have buffered data which won't be caught by select.
@ -300,7 +298,7 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
if context.cancelled:
raise _OperationCancelled("hello cancelled")
if readable:
@ -313,21 +311,19 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
def _receive_data_on_socket(
sock_info: SocketInfo, length: int, deadline: Optional[float]
) -> memoryview:
def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
wait_for_read(sock_info, deadline)
wait_for_read(conn, deadline)
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None:
sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0))
chunk_length = sock_info.sock.recv_into(mv[bytes_read:])
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out")
except OSError as exc: # noqa: B014

View File

@ -614,19 +614,19 @@ class _CancellationContext:
return self._cancelled
class SocketInfo:
"""Store a socket with some metadata.
class Connection:
"""Store a connection with some metadata.
:Parameters:
- `sock`: a raw socket object
- `conn`: a raw connection object
- `pool`: a Pool instance
- `address`: the server's (host, port)
- `id`: the id of this socket in it's pool
"""
def __init__(self, sock, pool, address, id):
def __init__(self, conn, pool, address, id):
self.pool_ref = weakref.ref(pool)
self.sock = sock
self.conn = conn
self.address = address
self.id = id
self.authed = set()
@ -673,12 +673,12 @@ class SocketInfo:
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
def set_socket_timeout(self, timeout):
"""Cache last timeout to avoid duplicate calls to sock.settimeout."""
def set_conn_timeout(self, timeout):
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
if timeout == self.last_timeout:
return
self.last_timeout = timeout
self.sock.settimeout(timeout)
self.conn.settimeout(timeout)
def apply_timeout(self, client, cmd):
# CSOT: use remaining timeout when set.
@ -686,7 +686,7 @@ class SocketInfo:
if timeout is None:
# Reset the socket timeout unless we're performing a streaming monitor check.
if not self.more_to_come:
self.set_socket_timeout(self.opts.socket_timeout)
self.set_conn_timeout(self.opts.socket_timeout)
return None
# RTT validation.
rtt = _csot.get_rtt()
@ -701,7 +701,7 @@ class SocketInfo:
)
if cmd is not None:
cmd["maxTimeMS"] = int(max_time_ms * 1000)
self.set_socket_timeout(timeout)
self.set_conn_timeout(timeout)
return timeout
def pin_txn(self):
@ -715,9 +715,9 @@ class SocketInfo:
def unpin(self):
pool = self.pool_ref()
if pool:
pool.return_socket(self)
pool.checkin(self)
else:
self.close_socket(ConnectionClosedReason.STALE)
self.close_conn(ConnectionClosedReason.STALE)
def hello_cmd(self):
# Handshake spec requires us to use OP_MSG+hello command for the
@ -748,7 +748,7 @@ class SocketInfo:
awaitable = True
# If connect_timeout is None there is no timeout.
if self.opts.connect_timeout:
self.set_socket_timeout(self.opts.connect_timeout + heartbeat_frequency)
self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency)
if not performing_handshake and cluster_time is not None:
cmd["$clusterTime"] = cluster_time
@ -919,7 +919,7 @@ class SocketInfo:
)
try:
self.sock.sendall(message)
self.conn.sendall(message)
except BaseException as error:
self._raise_connection_failure(error)
@ -999,15 +999,15 @@ class SocketInfo:
if session._client is not client:
raise InvalidOperation("Can only use session with the MongoClient that started it")
def close_socket(self, reason):
def close_conn(self, reason):
"""Close this connection with a reason."""
if self.closed:
return
self._close_socket()
self._close_conn()
if reason and self.enabled_for_cmap:
self.listeners.publish_connection_closed(self.address, self.id, reason)
def _close_socket(self):
def _close_conn(self):
"""Close this connection."""
if self.closed:
return
@ -1017,13 +1017,13 @@ class SocketInfo:
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
self.sock.close()
self.conn.close()
except Exception:
pass
def socket_closed(self):
def conn_closed(self):
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.sock)
return self.socket_checker.socket_closed(self.conn)
def send_cluster_time(self, command, session, client):
"""Add $clusterTime."""
@ -1060,12 +1060,12 @@ class SocketInfo:
# KeyboardInterrupt from the start, rather than as an initial
# socket.error, so we catch that, close the socket, and reraise it.
#
# The connection closed event will be emitted later in return_socket.
# The connection closed event will be emitted later in checkin.
if self.ready:
reason = None
else:
reason = ConnectionClosedReason.ERROR
self.close_socket(reason)
self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
_raise_connection_failure(self.address, error)
@ -1073,17 +1073,17 @@ class SocketInfo:
raise
def __eq__(self, other):
return self.sock == other.sock
return self.conn == other.conn
def __ne__(self, other):
return not self == other
def __hash__(self):
return hash(self.sock)
return hash(self.conn)
def __repr__(self):
return "SocketInfo({}){} at {}".format(
repr(self.sock),
return "Connection({}){} at {}".format(
repr(self.conn),
self.closed and " CLOSED" or "",
id(self),
)
@ -1256,7 +1256,7 @@ class Pool:
:Parameters:
- `address`: a (hostname, port) tuple
- `options`: a PoolOptions instance
- `handshake`: whether to call hello for each new SocketInfo
- `handshake`: whether to call hello for each new Connection
"""
if options.pause_enabled:
self.state = PoolState.PAUSED
@ -1268,7 +1268,7 @@ class Pool:
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
# and returned to pool from the left side. Stale sockets removed
# from the right side.
self.sockets: collections.deque = collections.deque()
self.conns: collections.deque = collections.deque()
self.lock = _create_lock()
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
@ -1344,17 +1344,17 @@ class Pool:
self.active_sockets = 0
self.operation_count = 0
if service_id is None:
sockets, self.sockets = self.sockets, collections.deque()
sockets, self.conns = self.conns, collections.deque()
else:
discard: collections.deque = collections.deque()
keep: collections.deque = collections.deque()
for sock_info in self.sockets:
if sock_info.service_id == service_id:
discard.append(sock_info)
for conn in self.conns:
if conn.service_id == service_id:
discard.append(conn)
else:
keep.append(sock_info)
keep.append(conn)
sockets = discard
self.sockets = keep
self.conns = keep
if close:
self.state = PoolState.CLOSED
@ -1367,15 +1367,15 @@ class Pool:
# PoolClosedEvent but that reset() SHOULD close sockets *after*
# publishing the PoolClearedEvent.
if close:
for sock_info in sockets:
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
listeners.publish_pool_closed(self.address)
else:
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
listeners.publish_pool_cleared(self.address, service_id=service_id)
for sock_info in sockets:
sock_info.close_socket(ConnectionClosedReason.STALE)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
def update_is_writable(self, is_writable):
"""Updates the is_writable attribute on all sockets currently in the
@ -1383,7 +1383,7 @@ class Pool:
"""
self.is_writable = is_writable
with self.lock:
for _socket in self.sockets:
for _socket in self.conns:
_socket.update_is_writable(self.is_writable)
def reset(self, service_id=None):
@ -1412,16 +1412,16 @@ class Pool:
if self.opts.max_idle_time_seconds is not None:
with self.lock:
while (
self.sockets
and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
self.conns
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
sock_info = self.sockets.pop()
sock_info.close_socket(ConnectionClosedReason.IDLE)
conn = self.conns.pop()
conn.close_conn(ConnectionClosedReason.IDLE)
while True:
with self.size_cond:
# There are enough sockets in the pool.
if len(self.sockets) + self.active_sockets >= self.opts.min_pool_size:
if len(self.conns) + self.active_sockets >= self.opts.min_pool_size:
return
if self.requests >= self.opts.min_pool_size:
return
@ -1435,14 +1435,14 @@ class Pool:
return
self._pending += 1
incremented = True
sock_info = self.connect()
conn = self.connect()
with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
sock_info.close_socket(ConnectionClosedReason.STALE)
conn.close_conn(ConnectionClosedReason.STALE)
return
self.sockets.appendleft(sock_info)
self.conns.appendleft(conn)
finally:
if incremented:
# Notify after adding the socket to the pool.
@ -1455,12 +1455,12 @@ class Pool:
self.size_cond.notify()
def connect(self, handler=None):
"""Connect to Mongo and return a new SocketInfo.
"""Connect to Mongo and return a new Connection.
Can raise ConnectionFailure.
Note that the pool does not keep a reference to the socket -- you
must call return_socket() when you're done with it.
must call checkin() when you're done with it.
"""
with self.lock:
conn_id = self.next_connection_id
@ -1483,33 +1483,33 @@ class Pool:
raise
sock_info = SocketInfo(sock, self, self.address, conn_id)
conn = Connection(sock, self, self.address, conn_id)
try:
if self.handshake:
sock_info.hello()
self.is_writable = sock_info.is_writable
conn.hello()
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(sock_info, completed_handshake=False)
handler.contribute_socket(conn, completed_handshake=False)
sock_info.authenticate()
conn.authenticate()
except BaseException:
sock_info.close_socket(ConnectionClosedReason.ERROR)
conn.close_conn(ConnectionClosedReason.ERROR)
raise
return sock_info
return conn
@contextlib.contextmanager
def get_socket(self, handler=None):
"""Get a socket from the pool. Use with a "with" statement.
def checkout(self, handler=None):
"""Get a connection from the pool. Use with a "with" statement.
Returns a :class:`SocketInfo` object wrapping a connected
Returns a :class:`Connection` object wrapping a connected
:class:`socket.socket`.
This method should always be used in a with-statement::
with pool.get_socket() as socket_info:
socket_info.send_message(msg)
data = socket_info.receive_message(op_code, request_id)
with pool.get_conn() as connection:
connection.send_message(msg)
data = connection.receive_message(op_code, request_id)
Can raise ConnectionFailure or OperationFailure.
@ -1520,36 +1520,36 @@ class Pool:
if self.enabled_for_cmap:
listeners.publish_connection_check_out_started(self.address)
sock_info = self._get_socket(handler=handler)
conn = self._get_conn(handler=handler)
if self.enabled_for_cmap:
listeners.publish_connection_checked_out(self.address, sock_info.id)
listeners.publish_connection_checked_out(self.address, conn.id)
try:
yield sock_info
yield conn
except BaseException:
# Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the
# connection and it is responsible for checking the connection
# back into the pool.
pinned = sock_info.pinned_txn or sock_info.pinned_cursor
pinned = conn.pinned_txn or conn.pinned_cursor
if handler:
# Perform SDAM error handling rules while the connection is
# still checked out.
exc_type, exc_val, _ = sys.exc_info()
handler.handle(exc_type, exc_val)
if not pinned and sock_info.active:
self.return_socket(sock_info)
if not pinned and conn.active:
self.checkin(conn)
raise
if sock_info.pinned_txn:
if conn.pinned_txn:
with self.lock:
self.__pinned_sockets.add(sock_info)
self.__pinned_sockets.add(conn)
self.ntxns += 1
elif sock_info.pinned_cursor:
elif conn.pinned_cursor:
with self.lock:
self.__pinned_sockets.add(sock_info)
self.__pinned_sockets.add(conn)
self.ncursors += 1
elif sock_info.active:
self.return_socket(sock_info)
elif conn.active:
self.checkin(conn)
def _raise_if_not_ready(self, emit_event):
if self.state != PoolState.READY:
@ -1559,8 +1559,8 @@ class Pool:
)
_raise_connection_failure(self.address, AutoReconnect("connection pool paused"))
def _get_socket(self, handler=None):
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
def _get_conn(self, handler=None):
"""Get or create a Connection. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of
# what could go wrong otherwise
@ -1600,7 +1600,7 @@ class Pool:
self.requests += 1
# We've now acquired the semaphore and must release it on error.
sock_info = None
conn = None
incremented = False
emitted_event = False
try:
@ -1608,40 +1608,40 @@ class Pool:
self.active_sockets += 1
incremented = True
while sock_info is None:
while conn is None:
# CMAP: we MUST wait for either maxConnecting OR for a socket
# to be checked back into the pool.
with self._max_connecting_cond:
self._raise_if_not_ready(emit_event=False)
while not (self.sockets or self._pending < self._max_connecting):
while not (self.conns or self._pending < self._max_connecting):
if not _cond_wait(self._max_connecting_cond, deadline):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.sockets or self._pending < self._max_connecting:
if self.conns or self._pending < self._max_connecting:
self._max_connecting_cond.notify()
emitted_event = True
self._raise_wait_queue_timeout()
self._raise_if_not_ready(emit_event=False)
try:
sock_info = self.sockets.popleft()
conn = self.conns.popleft()
except IndexError:
self._pending += 1
if sock_info: # We got a socket from the pool
if self._perished(sock_info):
sock_info = None
if conn: # We got a socket from the pool
if self._perished(conn):
conn = None
continue
else: # We need to create a new connection
try:
sock_info = self.connect(handler=handler)
conn = self.connect(handler=handler)
finally:
with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
except BaseException:
if sock_info:
if conn:
# We checked out a socket but authentication failed.
sock_info.close_socket(ConnectionClosedReason.ERROR)
conn.close_conn(ConnectionClosedReason.ERROR)
with self.size_cond:
self.requests -= 1
if incremented:
@ -1654,45 +1654,45 @@ class Pool:
)
raise
sock_info.active = True
return sock_info
conn.active = True
return conn
def return_socket(self, sock_info):
"""Return the socket to the pool, or if it's closed discard it.
def checkin(self, conn):
"""Return the connection to the pool, or if it's closed discard it.
:Parameters:
- `sock_info`: The socket to check into the pool.
- `conn`: The connection to check into the pool.
"""
txn = sock_info.pinned_txn
cursor = sock_info.pinned_cursor
sock_info.active = False
sock_info.pinned_txn = False
sock_info.pinned_cursor = False
self.__pinned_sockets.discard(sock_info)
txn = conn.pinned_txn
cursor = conn.pinned_cursor
conn.active = False
conn.pinned_txn = False
conn.pinned_cursor = False
self.__pinned_sockets.discard(conn)
listeners = self.opts._event_listeners
if self.enabled_for_cmap:
listeners.publish_connection_checked_in(self.address, sock_info.id)
listeners.publish_connection_checked_in(self.address, conn.id)
if self.pid != os.getpid():
self.reset_without_pause()
else:
if self.closed:
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
elif sock_info.closed:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
elif conn.closed:
# CMAP requires the closed event be emitted after the check in.
if self.enabled_for_cmap:
listeners.publish_connection_closed(
self.address, sock_info.id, ConnectionClosedReason.ERROR
self.address, conn.id, ConnectionClosedReason.ERROR
)
else:
with self.lock:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if self.stale_generation(sock_info.generation, sock_info.service_id):
sock_info.close_socket(ConnectionClosedReason.STALE)
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
else:
sock_info.update_last_checkin_time()
sock_info.update_is_writable(self.is_writable)
self.sockets.appendleft(sock_info)
conn.update_last_checkin_time()
conn.update_is_writable(self.is_writable)
self.conns.appendleft(conn)
# Notify any threads waiting to create a connection.
self._max_connecting_cond.notify()
@ -1706,7 +1706,7 @@ class Pool:
self.operation_count -= 1
self.size_cond.notify()
def _perished(self, sock_info):
def _perished(self, conn):
"""Return True and close the connection if it is "perished".
This side-effecty function checks if this socket has been idle for
@ -1720,24 +1720,24 @@ class Pool:
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
"""
idle_time_seconds = sock_info.idle_time_seconds()
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
if (
self.opts.max_idle_time_seconds is not None
and idle_time_seconds > self.opts.max_idle_time_seconds
):
sock_info.close_socket(ConnectionClosedReason.IDLE)
conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds
):
if sock_info.socket_closed():
sock_info.close_socket(ConnectionClosedReason.ERROR)
if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR)
return True
if self.stale_generation(sock_info.generation, sock_info.service_id):
sock_info.close_socket(ConnectionClosedReason.STALE)
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
return True
return False
@ -1772,5 +1772,5 @@ class Pool:
# Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__.
for sock_info in self.sockets:
sock_info.close_socket(None)
for conn in self.conns:
conn.close_conn(None)

View File

@ -81,7 +81,7 @@ def _is_ip_address(address):
return False
# According to the docs for Connection.send it can raise
# According to the docs for socket.send it can raise
# WantX509LookupError and should be retried.
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
@ -347,7 +347,7 @@ class SSLContext:
server_hostname=None,
session=None,
):
"""Wrap an existing Python socket sock and return a TLS socket
"""Wrap an existing Python socket connection and return a TLS socket
object.
"""
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
from datetime import timedelta
from pymongo.message import _OpMsg, _OpReply
from pymongo.pool import SocketInfo
from pymongo.pool import Connection
from pymongo.typings import _Address
@ -85,13 +85,13 @@ class Response:
class PinnedResponse(Response):
__slots__ = ("_socket_info", "_more_to_come")
__slots__ = ("_conn", "_more_to_come")
def __init__(
self,
data: Union[_OpMsg, _OpReply],
address: _Address,
socket_info: SocketInfo,
conn: Connection,
request_id: int,
duration: Optional[timedelta],
from_command: bool,
@ -103,7 +103,7 @@ class PinnedResponse(Response):
:Parameters:
- `data`: A network response message.
- `address`: (host, port) of the source server.
- `socket_info`: The SocketInfo used for the initial query.
- `conn`: The Connection used for the initial query.
- `request_id`: The request id of this operation.
- `duration`: The duration of the operation.
- `from_command`: If the response is the result of a db command.
@ -112,18 +112,18 @@ class PinnedResponse(Response):
exhausted.
"""
super().__init__(data, address, request_id, duration, from_command, docs)
self._socket_info = socket_info
self._conn = conn
self._more_to_come = more_to_come
@property
def socket_info(self) -> SocketInfo:
"""The SocketInfo used for the initial query.
def conn(self) -> Connection:
"""The Connection used for the initial query.
The server will send batches on this socket, without waiting for
getMores from the client, until the result set is exhausted or there
is an error.
"""
return self._socket_info
return self._conn
@property
def more_to_come(self) -> bool:

View File

@ -42,7 +42,7 @@ if TYPE_CHECKING:
from pymongo.mongo_client import _MongoClientErrorHandler
from pymongo.monitor import Monitor
from pymongo.monitoring import _EventListeners
from pymongo.pool import Pool, SocketInfo
from pymongo.pool import Connection, Pool
from pymongo.server_description import ServerDescription
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
@ -105,7 +105,7 @@ class Server:
@_handle_reauth
def run_operation(
self,
sock_info: SocketInfo,
conn: Connection,
operation: Union[_Query, _GetMore],
read_preference: bool,
listeners: _EventListeners,
@ -118,7 +118,7 @@ class Server:
Can raise ConnectionFailure, OperationFailure, etc.
:Parameters:
- `sock_info`: A SocketInfo instance.
- `conn`: A Connection instance.
- `operation`: A _Query or _GetMore object.
- `read_preference`: The read preference to use.
- `listeners`: Instance of _EventListeners or None.
@ -129,27 +129,27 @@ class Server:
if publish:
start = datetime.now()
use_cmd = operation.use_command(sock_info)
more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come
use_cmd = operation.use_command(conn)
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
if more_to_come:
request_id = 0
else:
message = operation.get_message(read_preference, sock_info, use_cmd)
message = operation.get_message(read_preference, conn, use_cmd)
request_id, data, max_doc_size = self._split_message(message)
if publish:
cmd, dbn = operation.as_command(sock_info)
cmd, dbn = operation.as_command(conn)
listeners.publish_command_start(
cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id
cmd, dbn, request_id, conn.address, service_id=conn.service_id
)
start = datetime.now()
try:
if more_to_come:
reply = sock_info.receive_message(None)
reply = conn.receive_message(None)
else:
sock_info.send_message(data, max_doc_size)
reply = sock_info.receive_message(request_id)
conn.send_message(data, max_doc_size)
reply = conn.receive_message(request_id)
# Unpack and check for command errors.
if use_cmd:
@ -168,7 +168,7 @@ class Server:
if use_cmd:
first = docs[0]
operation.client._process_response(first, operation.session)
_check_command_response(first, sock_info.max_wire_version)
_check_command_response(first, conn.max_wire_version)
except Exception as exc:
if publish:
duration = datetime.now() - start
@ -181,8 +181,8 @@ class Server:
failure,
operation.name,
request_id,
sock_info.address,
service_id=sock_info.service_id,
conn.address,
service_id=conn.service_id,
)
raise
@ -205,8 +205,8 @@ class Server:
res,
operation.name,
request_id,
sock_info.address,
service_id=sock_info.service_id,
conn.address,
service_id=conn.service_id,
)
# Decrypt response.
@ -219,7 +219,7 @@ class Server:
response: Response
if client._should_pin_cursor(operation.session) or operation.exhaust:
sock_info.pin_cursor()
conn.pin_cursor()
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the
# more_to_come flag is set.
@ -227,12 +227,12 @@ class Server:
else:
# In OP_REPLY, the server keeps sending until cursor_id is 0.
more_to_come = bool(operation.exhaust and reply.cursor_id)
if operation.sock_mgr:
operation.sock_mgr.update_exhaust(more_to_come)
if operation.conn_mgr:
operation.conn_mgr.update_exhaust(more_to_come)
response = PinnedResponse(
data=reply,
address=self._description.address,
socket_info=sock_info,
conn=conn,
duration=duration,
request_id=request_id,
from_command=use_cmd,
@ -251,10 +251,10 @@ class Server:
return response
def get_socket(
def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None
) -> ContextManager[SocketInfo]:
return self.pool.get_socket(handler)
) -> ContextManager[Connection]:
return self.pool.checkout(handler)
@property
def description(self) -> ServerDescription:

View File

@ -106,7 +106,7 @@ class TestHandshake(unittest.TestCase):
self.addCleanup(client.close)
# New monitoring sockets send data during handshake.
# New monitoring connections send data during handshake.
heartbeat = primary.receives("ismaster")
_check_handshake_data(heartbeat)
heartbeat.ok(primary_response)
@ -169,7 +169,7 @@ class TestHandshake(unittest.TestCase):
self.addCleanup(client.close)
# New monitoring sockets send data during handshake.
# New monitoring connections send data during handshake.
heartbeat = server.receives("ismaster")
heartbeat.ok(primary_response)

View File

@ -38,7 +38,7 @@ class MockPool(Pool):
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
@contextlib.contextmanager
def get_socket(self, handler=None):
def checkout(self, handler=None):
client = self.client
host_and_port = f"{self.mock_host}:{self.mock_port}"
if host_and_port in client.mock_down_hosts:
@ -48,10 +48,10 @@ class MockPool(Pool):
client.mock_standalones + client.mock_members + client.mock_mongoses
), ("bad host: %s" % host_and_port)
with Pool.get_socket(self, handler) as sock_info:
sock_info.mock_host = self.mock_host
sock_info.mock_port = self.mock_port
yield sock_info
with Pool.checkout(self, handler) as conn:
conn.mock_host = self.mock_host
conn.mock_port = self.mock_port
yield conn
class DummyMonitor:

View File

@ -60,7 +60,7 @@ class AutoAuthenticateThread(threading.Thread):
"""Used in testing threaded authentication.
This does collection.find_one() with a 1-second delay to ensure it must
check out and authenticate multiple sockets from the pool concurrently.
check out and authenticate multiple connections from the pool concurrently.
:Parameters:
`collection`: An auth-protected collection containing one document.
@ -217,7 +217,7 @@ class TestGSSAPI(unittest.TestCase):
# Need one document in the collection. AutoAuthenticateThread does
# collection.find_one with a 1-second delay, forcing it to check out
# multiple sockets from the pool concurrently, proving that
# multiple connections from the pool concurrently, proving that
# auto-authentication works with GSSAPI.
collection = db.test
if not collection.count_documents({}):

View File

@ -96,7 +96,7 @@ from pymongo.errors import (
)
from pymongo.mongo_client import MongoClient
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
from pymongo.pool import _METADATA, PoolOptions, SocketInfo
from pymongo.pool import _METADATA, Connection, PoolOptions
from pymongo.read_preferences import ReadPreference
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import readable_server_selector, writable_server_selector
@ -538,13 +538,13 @@ class TestClient(IntegrationTest):
def test_max_idle_time_reaper_default(self):
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper doesn't remove sockets when maxIdleTimeMS not set
# Assert reaper doesn't remove connections when maxIdleTimeMS not set
client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info:
with server._pool.checkout() as conn:
pass
self.assertEqual(1, len(server._pool.sockets))
self.assertTrue(sock_info in server._pool.sockets)
self.assertEqual(1, len(server._pool.conns))
self.assertTrue(conn in server._pool.conns)
client.close()
def test_max_idle_time_reaper_removes_stale_minPoolSize(self):
@ -552,27 +552,27 @@ class TestClient(IntegrationTest):
# Assert reaper removes idle socket and replaces it with a new one
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info:
with server._pool.checkout() as conn:
pass
# When the reaper runs at the same time as the get_socket, two
# sockets could be created and checked into the pool.
self.assertGreaterEqual(len(server._pool.sockets), 1)
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket")
wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket")
# connections could be created and checked into the pool.
self.assertGreaterEqual(len(server._pool.conns), 1)
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: 1 <= len(server._pool.conns), "replace stale socket")
client.close()
def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper respects maxPoolSize when adding new sockets.
# Assert reaper respects maxPoolSize when adding new connections.
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info:
with server._pool.checkout() as conn:
pass
# When the reaper runs at the same time as the get_socket,
# maxPoolSize=1 should prevent two sockets from being created.
self.assertEqual(1, len(server._pool.sockets))
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket")
wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket")
# maxPoolSize=1 should prevent two connections from being created.
self.assertEqual(1, len(server._pool.conns))
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: 1 == len(server._pool.conns), "replace stale socket")
client.close()
def test_max_idle_time_reaper_removes_stale(self):
@ -580,15 +580,15 @@ class TestClient(IntegrationTest):
# Assert reaper has removed idle socket and NOT replaced it
client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info_one:
with server._pool.checkout() as conn_one:
pass
# Assert that the pool does not close sockets prematurely.
# Assert that the pool does not close connections prematurely.
time.sleep(0.300)
with server._pool.get_socket() as sock_info_two:
with server._pool.checkout() as conn_two:
pass
self.assertIs(sock_info_one, sock_info_two)
self.assertIs(conn_one, conn_two)
wait_until(
lambda: 0 == len(server._pool.sockets),
lambda: 0 == len(server._pool.conns),
"stale socket reaped and new one NOT added to the pool",
)
client.close()
@ -597,48 +597,50 @@ class TestClient(IntegrationTest):
with client_knobs(kill_cursor_frequency=0.1):
client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector)
self.assertEqual(0, len(server._pool.sockets))
self.assertEqual(0, len(server._pool.conns))
# Assert that pool started up at minPoolSize
client = rs_or_single_client(minPoolSize=10)
server = client._get_topology().select_server(readable_server_selector)
wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets")
wait_until(
lambda: 10 == len(server._pool.conns), "pool initialized with 10 connections"
)
# Assert that if a socket is closed, a new one takes its place
with server._pool.get_socket() as sock_info:
sock_info.close_socket(None)
with server._pool.checkout() as conn:
conn.close_conn(None)
wait_until(
lambda: 10 == len(server._pool.sockets),
lambda: 10 == len(server._pool.conns),
"a closed socket gets replaced from the pool",
)
self.assertFalse(sock_info in server._pool.sockets)
self.assertFalse(conn in server._pool.conns)
def test_max_idle_time_checkout(self):
# Use high frequency to test _get_socket_no_auth.
with client_knobs(kill_cursor_frequency=99999999):
client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info:
with server._pool.checkout() as conn:
pass
self.assertEqual(1, len(server._pool.sockets))
self.assertEqual(1, len(server._pool.conns))
time.sleep(1) # Sleep so that the socket becomes stale.
with server._pool.get_socket() as new_sock_info:
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(1, len(server._pool.sockets))
self.assertFalse(sock_info in server._pool.sockets)
self.assertTrue(new_sock_info in server._pool.sockets)
with server._pool.checkout() as new_con:
self.assertNotEqual(conn, new_con)
self.assertEqual(1, len(server._pool.conns))
self.assertFalse(conn in server._pool.conns)
self.assertTrue(new_con in server._pool.conns)
# Test that sockets are reused if maxIdleTimeMS is not set.
# Test that connections are reused if maxIdleTimeMS is not set.
client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector)
with server._pool.get_socket() as sock_info:
with server._pool.checkout() as conn:
pass
self.assertEqual(1, len(server._pool.sockets))
self.assertEqual(1, len(server._pool.conns))
time.sleep(1)
with server._pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info)
self.assertEqual(1, len(server._pool.sockets))
with server._pool.checkout() as new_con:
self.assertEqual(conn, new_con)
self.assertEqual(1, len(server._pool.conns))
def test_constants(self):
"""This test uses MongoClient explicitly to make sure that host and
@ -933,11 +935,11 @@ class TestClient(IntegrationTest):
topology = client._topology
client.close()
for server in topology._servers.values():
self.assertFalse(server._pool.sockets)
self.assertFalse(server._pool.conns)
self.assertTrue(server._monitor._executor._stopped)
self.assertTrue(server._monitor._rtt_monitor._executor._stopped)
self.assertFalse(server._monitor._pool.sockets)
self.assertFalse(server._monitor._rtt_monitor._pool.sockets)
self.assertFalse(server._monitor._pool.conns)
self.assertFalse(server._monitor._rtt_monitor._pool.conns)
def test_bad_uri(self):
with self.assertRaises(InvalidURI):
@ -1130,8 +1132,8 @@ class TestClient(IntegrationTest):
def test_socketKeepAlive(self):
pool = get_pool(self.client)
with pool.get_socket() as sock_info:
keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
with pool.checkout() as conn:
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@no_type_check
@ -1184,7 +1186,7 @@ class TestClient(IntegrationTest):
# The socket used for the previous commands has been returned to the
# pool
self.assertEqual(1, len(get_pool(client).sockets))
self.assertEqual(1, len(get_pool(client).conns))
with contextlib.closing(client):
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
@ -1223,7 +1225,7 @@ class TestClient(IntegrationTest):
# main thread while find() is in-progress: On Windows, SIGALRM is
# unavailable so we use a second thread. In our Evergreen setup on
# Linux, the thread technique causes an error in the test at
# sock.recv(): TypeError: 'int' object is not callable
# conn.recv(): TypeError: 'int' object is not callable
# We don't know what causes this, so we hack around it.
if sys.platform == "win32":
@ -1271,16 +1273,16 @@ class TestClient(IntegrationTest):
self.addCleanup(client.close)
client.pymongo_test.test.find_one()
pool = get_pool(client)
socket_count = len(pool.sockets)
socket_count = len(pool.conns)
self.assertGreaterEqual(socket_count, 1)
old_sock_info = next(iter(pool.sockets))
old_conn = next(iter(pool.conns))
client.pymongo_test.test.drop()
client.pymongo_test.test.insert_one({"_id": "foo"})
self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"})
self.assertEqual(socket_count, len(pool.sockets))
new_sock_info = next(iter(pool.sockets))
self.assertEqual(old_sock_info, new_sock_info)
self.assertEqual(socket_count, len(pool.conns))
new_con = next(iter(pool.conns))
self.assertEqual(old_conn, new_con)
def test_lazy_connect_w0(self):
# Ensure that connect-on-demand works when the first operation is
@ -1326,13 +1328,13 @@ class TestClient(IntegrationTest):
connected(client)
# Cause a network error.
sock_info = one(pool.sockets)
sock_info.sock.close()
conn = one(pool.conns)
conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
with self.assertRaises(ConnectionFailure):
next(cursor)
self.assertTrue(sock_info.closed)
self.assertTrue(conn.closed)
# The semaphore was decremented despite the error.
self.assertEqual(0, pool.requests)
@ -1347,10 +1349,10 @@ class TestClient(IntegrationTest):
# Cause a network error on the actual socket.
pool = get_pool(c)
socket_info = one(pool.sockets)
socket_info.sock.close()
socket_info = one(pool.conns)
socket_info.conn.close()
# SocketInfo.authenticate logs, but gets a socket.error. Should be
# Connection.authenticate logs, but gets a socket.error. Should be
# reraised as AutoReconnect.
self.assertRaises(AutoReconnect, c.test.collection.find_one)
@ -1586,7 +1588,7 @@ class TestClient(IntegrationTest):
self.addCleanup(delattr, pool, "connect")
# Wait for the background thread to start creating connections
wait_until(lambda: len(pool.sockets) > 1, "start creating connections")
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
# Assert that application operations do not block.
for _ in range(10):
@ -1847,7 +1849,7 @@ class TestExhaustCursor(IntegrationTest):
collection = client.pymongo_test.test
pool = get_pool(client)
sock_info = one(pool.sockets)
conn = one(pool.conns)
# This will cause OperationFailure in all mongo versions since
# the value for $orderby must be a document.
@ -1856,10 +1858,10 @@ class TestExhaustCursor(IntegrationTest):
)
self.assertRaises(OperationFailure, cursor.next)
self.assertFalse(sock_info.closed)
self.assertFalse(conn.closed)
# The socket was checked in and the semaphore was decremented.
self.assertIn(sock_info, pool.sockets)
self.assertIn(conn, pool.conns)
self.assertEqual(0, pool.requests)
def test_exhaust_getmore_server_error(self):
@ -1874,7 +1876,7 @@ class TestExhaustCursor(IntegrationTest):
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
sock_info = one(pool.sockets)
conn = one(pool.conns)
cursor = collection.find(cursor_type=CursorType.EXHAUST)
@ -1884,21 +1886,21 @@ class TestExhaustCursor(IntegrationTest):
# Cause a server error on getmore.
def receive_message(request_id):
# Discard the actual server response.
SocketInfo.receive_message(sock_info, request_id)
Connection.receive_message(conn, request_id)
# responseFlags bit 1 is QueryFailure.
msg = struct.pack("<iiiii", 1 << 1, 0, 0, 0, 0)
msg += encode({"$err": "mock err", "code": 0})
return message._OpReply.unpack(msg)
sock_info.receive_message = receive_message
conn.receive_message = receive_message
self.assertRaises(OperationFailure, list, cursor)
# Unpatch the instance.
del sock_info.receive_message
del conn.receive_message
# The socket is returned to the pool and it still works.
self.assertEqual(200, collection.count_documents({}))
self.assertIn(sock_info, pool.sockets)
self.assertIn(conn, pool.conns)
def test_exhaust_query_network_error(self):
# When doing an exhaust query, the socket stays checked out on success
@ -1909,15 +1911,15 @@ class TestExhaustCursor(IntegrationTest):
pool._check_interval_seconds = None # Never check.
# Cause a network error.
sock_info = one(pool.sockets)
sock_info.sock.close()
conn = one(pool.conns)
conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
self.assertRaises(ConnectionFailure, cursor.next)
self.assertTrue(sock_info.closed)
self.assertTrue(conn.closed)
# The socket was closed and the semaphore was decremented.
self.assertNotIn(sock_info, pool.sockets)
self.assertNotIn(conn, pool.conns)
self.assertEqual(0, pool.requests)
def test_exhaust_getmore_network_error(self):
@ -1936,15 +1938,15 @@ class TestExhaustCursor(IntegrationTest):
cursor.next()
# Cause a network error.
sock_info = cursor._Cursor__sock_mgr.sock
sock_info.sock.close()
conn = cursor._Cursor__sock_mgr.conn
conn.conn.close()
# A getmore fails.
self.assertRaises(ConnectionFailure, list, cursor)
self.assertTrue(sock_info.closed)
self.assertTrue(conn.closed)
# The socket was closed and the semaphore was decremented.
self.assertNotIn(sock_info, pool.sockets)
self.assertNotIn(conn, pool.conns)
self.assertEqual(0, pool.requests)
def test_gevent_task(self):
@ -2243,10 +2245,10 @@ class TestClientPool(MockClientTest):
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
# Assert that we do not create connections to arbiters.
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertFalse(arbiter.pool.sockets)
self.assertFalse(arbiter.pool.conns)
# Assert that we do not create connections to unknown servers.
arbiter = c._topology.get_server_by_address(("d", 4))
self.assertFalse(arbiter.pool.sockets)
self.assertFalse(arbiter.pool.conns)
# Arbiter pool is not marked ready.
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 2)
@ -2271,7 +2273,7 @@ class TestClientPool(MockClientTest):
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.sockets), 1)
self.assertEqual(len(arbiter.pool.conns), 1)
# Arbiter pool is marked ready.
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)

View File

@ -123,19 +123,19 @@ class TestCMAP(IntegrationTest):
def check_out(self, op):
"""Run the 'checkOut' operation."""
label = op["label"]
with self.pool.get_socket() as sock_info:
with self.pool.checkout() as conn:
# Call 'pin_cursor' so we can hold the socket.
sock_info.pin_cursor()
conn.pin_cursor()
if label:
self.labels[label] = sock_info
self.labels[label] = conn
else:
self.addCleanup(sock_info.close_socket, None)
self.addCleanup(conn.close_conn, None)
def check_in(self, op):
"""Run the 'checkIn' operation."""
label = op["connection"]
sock_info = self.labels[label]
self.pool.return_socket(sock_info)
conn = self.labels[label]
self.pool.checkin(conn)
def ready(self, op):
"""Run the 'ready' operation."""
@ -270,7 +270,7 @@ class TestCMAP(IntegrationTest):
for t in self.targets.values():
t.join(5)
for conn in self.labels.values():
conn.close_socket(None)
conn.close_conn(None)
self.addCleanup(cleanup)
@ -444,7 +444,7 @@ class TestCMAP(IntegrationTest):
self.assertEqual(1, listener.event_count(PoolClearedEvent))
self.assertEqual(PoolState.READY, pool.state)
# Checking out a connection should succeed
with pool.get_socket():
with pool.checkout():
pass

View File

@ -1721,15 +1721,15 @@ class TestCollection(IntegrationTest):
# Make sure the socket is returned after exhaustion.
cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST)
next(cur)
self.assertEqual(0, len(pool.sockets))
self.assertEqual(0, len(pool.conns))
for _ in cur:
pass
self.assertEqual(1, len(pool.sockets))
self.assertEqual(1, len(pool.conns))
# Same as previous but don't call next()
for _ in client[self.db.name].test.find(cursor_type=CursorType.EXHAUST):
pass
self.assertEqual(1, len(pool.sockets))
self.assertEqual(1, len(pool.conns))
# If the Cursor instance is discarded before being completely iterated
# and the socket has pending data (more_to_come=True) we have to close
@ -1742,7 +1742,7 @@ class TestCollection(IntegrationTest):
next(cur)
else:
next(cur)
self.assertEqual(0, len(pool.sockets))
self.assertEqual(0, len(pool.conns))
if sys.platform.startswith("java") or "PyPy" in sys.version:
# Don't wait for GC or use gc.collect(), it's unreliable.
cur.close()
@ -1750,7 +1750,7 @@ class TestCollection(IntegrationTest):
# Wait until the background thread returns the socket.
wait_until(lambda: pool.active_sockets == 0, "return socket")
# The socket should be discarded.
self.assertEqual(0, len(pool.sockets))
self.assertEqual(0, len(pool.conns))
def test_distinct(self):
self.db.drop_collection("test")

View File

@ -47,7 +47,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for sockets that
# that the is_writable flag is properly updated for connections that
# survive a replica set election.
ensure_all_connected(cls.client)
cls.listener.reset()

View File

@ -94,7 +94,7 @@ def got_app_error(topology, app_error):
when = app_error["when"]
max_wire_version = app_error["maxWireVersion"]
# XXX: We could get better test coverage by mocking the errors on the
# Pool/SocketInfo.
# Pool/Connection.
try:
if error_type == "command":
_check_command_response(app_error["response"], max_wire_version)
@ -274,14 +274,14 @@ class TestIgnoreStaleErrors(IntegrationTest):
client.admin.command("ping")
pool = get_pool(client)
starting_generation = pool.gen.get_overall()
wait_until(lambda: len(pool.sockets) == N_THREADS, "created sockets")
wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
def mock_command(*args, **kwargs):
# Synchronize all threads to ensure they use the same generation.
barrier.wait()
raise AutoReconnect("mock SocketInfo.command error")
raise AutoReconnect("mock Connection.command error")
for sock in pool.sockets:
for sock in pool.conns:
sock.command = mock_command
def insert_command(i):

View File

@ -41,11 +41,11 @@ class TestLB(IntegrationTest):
# Tracked in PYTHON-3011
self.skipTest("Test is flaky on PyPy")
pool = get_pool(self.client)
nconns = len(pool.sockets)
n_conns = len(pool.conns)
self.db.test.find_one({})
self.assertEqual(len(pool.sockets), nconns)
self.assertEqual(len(pool.conns), n_conns)
list(self.db.test.aggregate([{"$limit": 1}]))
self.assertEqual(len(pool.sockets), nconns)
self.assertEqual(len(pool.conns), n_conns)
@client_context.require_load_balancer
def test_unpin_committed_transaction(self):

View File

@ -116,15 +116,15 @@ class SocketGetter(MongoThread):
self.state = "get_socket"
# Call 'pin_cursor' so we can hold the socket.
with self.pool.get_socket() as sock:
with self.pool.checkout() as sock:
sock.pin_cursor()
self.sock = sock
self.state = "sock"
self.state = "connection"
def __del__(self):
if self.sock:
self.sock.close_socket(None)
self.sock.close_conn(None)
def run_cases(client, cases):
@ -188,36 +188,36 @@ class TestPooling(_TestPoolingBase):
# Test Pool's _check_closed() method doesn't close a healthy socket.
cx_pool = self.create_pool(max_pool_size=10)
cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket() as sock_info:
with cx_pool.checkout() as conn:
pass
with cx_pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info)
with cx_pool.checkout() as new_connection:
self.assertEqual(conn, new_connection)
self.assertEqual(1, len(cx_pool.sockets))
self.assertEqual(1, len(cx_pool.conns))
def test_get_socket_and_exception(self):
# get_socket() returns socket after a non-network error.
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
with self.assertRaises(ZeroDivisionError):
with cx_pool.get_socket() as sock_info:
with cx_pool.checkout() as conn:
1 / 0
# Socket was returned, not closed.
with cx_pool.get_socket() as new_sock_info:
self.assertEqual(sock_info, new_sock_info)
with cx_pool.checkout() as new_connection:
self.assertEqual(conn, new_connection)
self.assertEqual(1, len(cx_pool.sockets))
self.assertEqual(1, len(cx_pool.conns))
def test_pool_removes_closed_socket(self):
# Test that Pool removes explicitly closed socket.
cx_pool = self.create_pool()
with cx_pool.get_socket() as sock_info:
# Use SocketInfo's API to close the socket.
sock_info.close_socket(None)
with cx_pool.checkout() as conn:
# Use Connection's API to close the socket.
conn.close_conn(None)
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(0, len(cx_pool.conns))
def test_pool_removes_dead_socket(self):
# Test that Pool removes dead socket and the socket doesn't return
@ -225,20 +225,20 @@ class TestPooling(_TestPoolingBase):
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
cx_pool._check_interval_seconds = 0 # Always check.
with cx_pool.get_socket() as sock_info:
# Simulate a closed socket without telling the SocketInfo it's
with cx_pool.checkout() as conn:
# Simulate a closed socket without telling the Connection it's
# closed.
sock_info.sock.close()
self.assertTrue(sock_info.socket_closed())
conn.conn.close()
self.assertTrue(conn.conn_closed())
with cx_pool.get_socket() as new_sock_info:
self.assertEqual(0, len(cx_pool.sockets))
self.assertNotEqual(sock_info, new_sock_info)
with cx_pool.checkout() as new_connection:
self.assertEqual(0, len(cx_pool.conns))
self.assertNotEqual(conn, new_connection)
self.assertEqual(1, len(cx_pool.sockets))
self.assertEqual(1, len(cx_pool.conns))
# Semaphore was released.
with cx_pool.get_socket():
with cx_pool.checkout():
pass
def test_socket_closed(self):
@ -282,13 +282,13 @@ class TestPooling(_TestPoolingBase):
def test_return_socket_after_reset(self):
pool = self.create_pool()
with pool.get_socket() as sock:
with pool.checkout() as sock:
self.assertEqual(pool.active_sockets, 1)
self.assertEqual(pool.operation_count, 1)
pool.reset()
self.assertTrue(sock.closed)
self.assertEqual(0, len(pool.sockets))
self.assertEqual(0, len(pool.conns))
self.assertEqual(pool.active_sockets, 0)
self.assertEqual(pool.operation_count, 0)
@ -299,20 +299,20 @@ class TestPooling(_TestPoolingBase):
cx_pool._check_interval_seconds = 0 # Always check.
self.addCleanup(cx_pool.close)
with cx_pool.get_socket() as sock_info:
# Simulate a closed socket without telling the SocketInfo it's
with cx_pool.checkout() as conn:
# Simulate a closed socket without telling the Connection it's
# closed.
sock_info.sock.close()
conn.conn.close()
# Swap pool's address with a bad one.
address, cx_pool.address = cx_pool.address, ("foo.com", 1234)
with self.assertRaises(AutoReconnect):
with cx_pool.get_socket():
with cx_pool.checkout():
pass
# Back to normal, semaphore was correctly released.
cx_pool.address = address
with cx_pool.get_socket():
with cx_pool.checkout():
pass
def test_wait_queue_timeout(self):
@ -320,10 +320,10 @@ class TestPooling(_TestPoolingBase):
pool = self.create_pool(max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
self.addCleanup(pool.close)
with pool.get_socket():
with pool.checkout():
start = time.time()
with self.assertRaises(ConnectionFailure):
with pool.get_socket():
with pool.checkout():
pass
duration = time.time() - start
@ -338,7 +338,7 @@ class TestPooling(_TestPoolingBase):
self.addCleanup(pool.close)
# Reach max_size.
with pool.get_socket() as s1:
with pool.checkout() as s1:
t = SocketGetter(self.c, pool)
t.start()
while t.state != "get_socket":
@ -347,10 +347,10 @@ class TestPooling(_TestPoolingBase):
time.sleep(1)
self.assertEqual(t.state, "get_socket")
while t.state != "sock":
while t.state != "connection":
time.sleep(0.1)
self.assertEqual(t.state, "sock")
self.assertEqual(t.state, "connection")
self.assertEqual(t.sock, s1)
def test_checkout_more_than_max_pool_size(self):
@ -359,7 +359,7 @@ class TestPooling(_TestPoolingBase):
socks = []
for _ in range(2):
# Call 'pin_cursor' so we can hold the socket.
with pool.get_socket() as sock:
with pool.checkout() as sock:
sock.pin_cursor()
socks.append(sock)
@ -373,7 +373,7 @@ class TestPooling(_TestPoolingBase):
self.assertEqual(t.state, "get_socket")
for socket_info in socks:
socket_info.close_socket(None)
socket_info.close_conn(None)
def test_maxConnecting(self):
client = rs_or_single_client()
@ -394,14 +394,14 @@ class TestPooling(_TestPoolingBase):
thread.join(10)
self.assertEqual(len(docs), 50)
self.assertLessEqual(len(pool.sockets), 50)
self.assertLessEqual(len(pool.conns), 50)
# TLS and auth make connection establishment more expensive than
# the query which leads to more threads hitting maxConnecting.
# The end result is fewer total connections and better latency.
if client_context.tls and client_context.auth_enabled:
self.assertLessEqual(len(pool.sockets), 30)
self.assertLessEqual(len(pool.conns), 30)
else:
self.assertLessEqual(len(pool.sockets), 50)
self.assertLessEqual(len(pool.conns), 50)
# MongoDB 4.4.1 with auth + ssl:
# maxConnecting = 2: 6 connections in ~0.231+ seconds
# maxConnecting = unbounded: 50 connections in ~0.642+ seconds
@ -409,7 +409,7 @@ class TestPooling(_TestPoolingBase):
# MongoDB 4.4.1 with no-auth no-ssl Python 3.8:
# maxConnecting = 2: 15-22 connections in ~0.108+ seconds
# maxConnecting = unbounded: 30+ connections in ~0.140+ seconds
print(len(pool.sockets))
print(len(pool.conns))
class TestPoolMaxSize(_TestPoolingBase):
@ -424,7 +424,7 @@ class TestPoolMaxSize(_TestPoolingBase):
collection.insert_one({})
# nthreads had better be much larger than max_pool_size to ensure that
# max_pool_size sockets are actually required at some point in this
# max_pool_size connections are actually required at some point in this
# test's execution.
cx_pool = get_pool(c)
nthreads = 10
@ -435,7 +435,7 @@ class TestPoolMaxSize(_TestPoolingBase):
def f():
for _ in range(5):
collection.find_one({"$where": delay(0.1)})
assert len(cx_pool.sockets) <= max_pool_size
assert len(cx_pool.conns) <= max_pool_size
with lock:
self.n_passed += 1
@ -447,7 +447,7 @@ class TestPoolMaxSize(_TestPoolingBase):
joinall(threads)
self.assertEqual(nthreads, self.n_passed)
self.assertTrue(len(cx_pool.sockets) > 1)
self.assertTrue(len(cx_pool.conns) > 1)
self.assertEqual(0, cx_pool.requests)
def test_max_pool_size_none(self):
@ -479,7 +479,7 @@ class TestPoolMaxSize(_TestPoolingBase):
joinall(threads)
self.assertEqual(nthreads, self.n_passed)
self.assertTrue(len(cx_pool.sockets) > 1)
self.assertTrue(len(cx_pool.conns) > 1)
self.assertEqual(cx_pool.max_pool_size, float("inf"))
def test_max_pool_size_zero(self):
@ -502,7 +502,7 @@ class TestPoolMaxSize(_TestPoolingBase):
# socket from pool" instead of AutoReconnect.
for _i in range(2):
with self.assertRaises(AutoReconnect) as context:
with test_pool.get_socket():
with test_pool.checkout():
pass
# Testing for AutoReconnect instead of ConnectionFailure, above,

View File

@ -284,18 +284,18 @@ class ReadPrefTester(MongoClient):
super().__init__(*args, **client_options)
@contextlib.contextmanager
def _socket_for_reads(self, read_preference, session):
context = super()._socket_for_reads(read_preference, session)
with context as (sock_info, read_preference):
self.record_a_read(sock_info.address)
yield sock_info, read_preference
def _conn_for_reads(self, read_preference, session):
context = super()._conn_for_reads(read_preference, session)
with context as (conn, read_preference):
self.record_a_read(conn.address)
yield conn, read_preference
@contextlib.contextmanager
def _socket_from_server(self, read_preference, server, session):
context = super()._socket_from_server(read_preference, server, session)
with context as (sock_info, read_preference):
self.record_a_read(sock_info.address)
yield sock_info, read_preference
def _conn_from_server(self, read_preference, server, session):
context = super()._conn_from_server(read_preference, server, session)
with context as (conn, read_preference):
self.record_a_read(conn.address)
yield conn, read_preference
def record_a_read(self, address):
server = self._get_topology().select_server_by_address(address, 0)

View File

@ -141,7 +141,7 @@ class TestProse(IntegrationTest):
)
self.addCleanup(client.close)
wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
wait_until(lambda: len(get_pool(client).sockets) >= 10, "create 10 connections")
wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections")
# Delay find commands on
delay_finds = {
"configureFailPoint": "failCommand",

View File

@ -279,12 +279,12 @@ class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener):
self.add_event(event)
class MockSocketInfo:
class MockConnection:
def __init__(self):
self.cancel_context = _CancellationContext()
self.more_to_come = False
def close_socket(self, reason):
def close_conn(self, reason):
pass
def __enter__(self):
@ -304,10 +304,10 @@ class MockPool:
def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id)
def get_socket(self, handler=None):
return MockSocketInfo()
def checkout(self, handler=None):
return MockConnection()
def return_socket(self, *args, **kwargs):
def checkin(self, *args, **kwargs):
pass
def _reset(self, service_id=None):