PYTHON-3879 Rename SocketInfo to Connection (#1329)
This commit is contained in:
parent
c945ec6302
commit
c88ae79e58
@ -78,7 +78,7 @@ will receive the following error::
|
||||
File "/Library/Python/2.7/site-packages/pymongo/collection.py", line 1560, in count
|
||||
return self._count(cmd, collation, session)
|
||||
File "/Library/Python/2.7/site-packages/pymongo/collection.py", line 1504, in _count
|
||||
with self._socket_for_reads() as (sock_info, slave_ok):
|
||||
with self._socket_for_reads() as (connection, slave_ok):
|
||||
File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/contextlib.py", line 17, in __enter__
|
||||
return self.gen.next()
|
||||
File "/Library/Python/2.7/site-packages/pymongo/mongo_client.py", line 982, in _socket_for_reads
|
||||
|
||||
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.database import Database
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server import Server
|
||||
from pymongo.typings import _Pipeline
|
||||
@ -52,7 +52,7 @@ class _AggregationCommand:
|
||||
explicit_session: bool,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
user_fields: Optional[MutableMapping[str, Any]] = None,
|
||||
result_processor: Optional[Callable[[Mapping[str, Any], SocketInfo], None]] = None,
|
||||
result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
if "explain" in options:
|
||||
@ -134,7 +134,7 @@ class _AggregationCommand:
|
||||
self,
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> CommandCursor:
|
||||
# Serialize command.
|
||||
@ -146,7 +146,7 @@ class _AggregationCommand:
|
||||
# - server version is >= 4.2 or
|
||||
# - server version is >= 3.2 and pipeline doesn't use $out
|
||||
if ("readConcern" not in cmd) and (
|
||||
not self._performs_write or (sock_info.max_wire_version >= 8)
|
||||
not self._performs_write or (conn.max_wire_version >= 8)
|
||||
):
|
||||
read_concern = self._target.read_concern
|
||||
else:
|
||||
@ -161,7 +161,7 @@ class _AggregationCommand:
|
||||
write_concern = None
|
||||
|
||||
# Run command.
|
||||
result = sock_info.command(
|
||||
result = conn.command(
|
||||
self._database.name,
|
||||
cmd,
|
||||
read_preference,
|
||||
@ -176,7 +176,7 @@ class _AggregationCommand:
|
||||
)
|
||||
|
||||
if self._result_processor:
|
||||
self._result_processor(result, sock_info)
|
||||
self._result_processor(result, conn)
|
||||
|
||||
# Extract cursor from result or mock/fake one if necessary.
|
||||
if "cursor" in result:
|
||||
@ -193,14 +193,14 @@ class _AggregationCommand:
|
||||
cmd_cursor = self._cursor_class(
|
||||
self._cursor_collection(cursor),
|
||||
cursor,
|
||||
sock_info.address,
|
||||
conn.address,
|
||||
batch_size=self._batch_size or 0,
|
||||
max_await_time_ms=self._max_await_time_ms,
|
||||
session=session,
|
||||
explicit_session=self._explicit_session,
|
||||
comment=self._options.get("comment"),
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(conn)
|
||||
return cmd_cursor
|
||||
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
@ -220,9 +220,7 @@ def _authenticate_scram_start(
|
||||
return nonce, first_bare, cmd
|
||||
|
||||
|
||||
def _authenticate_scram(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, mechanism: str
|
||||
) -> None:
|
||||
def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None:
|
||||
"""Authenticate using SCRAM."""
|
||||
username = credentials.username
|
||||
if mechanism == "SCRAM-SHA-256":
|
||||
@ -239,13 +237,13 @@ def _authenticate_scram(
|
||||
# Make local
|
||||
_hmac = hmac.HMAC
|
||||
|
||||
ctx = sock_info.auth_ctx
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
nonce, first_bare = ctx.scram_data
|
||||
res = ctx.speculative_authenticate
|
||||
else:
|
||||
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
|
||||
res = sock_info.command(source, cmd)
|
||||
res = conn.command(source, cmd)
|
||||
|
||||
server_first = res["payload"]
|
||||
parsed = _parse_scram_response(server_first)
|
||||
@ -285,7 +283,7 @@ def _authenticate_scram(
|
||||
("payload", Binary(client_final)),
|
||||
]
|
||||
)
|
||||
res = sock_info.command(source, cmd)
|
||||
res = conn.command(source, cmd)
|
||||
|
||||
parsed = _parse_scram_response(res["payload"])
|
||||
if not hmac.compare_digest(parsed[b"v"], server_sig):
|
||||
@ -301,7 +299,7 @@ def _authenticate_scram(
|
||||
("payload", Binary(b"")),
|
||||
]
|
||||
)
|
||||
res = sock_info.command(source, cmd)
|
||||
res = conn.command(source, cmd)
|
||||
if not res["done"]:
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
@ -345,7 +343,7 @@ def _canonicalize_hostname(hostname: str) -> str:
|
||||
return name[0].lower()
|
||||
|
||||
|
||||
def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using GSSAPI."""
|
||||
if not HAVE_KERBEROS:
|
||||
raise ConfigurationError(
|
||||
@ -358,7 +356,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
props = credentials.mechanism_properties
|
||||
# Starting here and continuing through the while loop below - establish
|
||||
# the security context. See RFC 4752, Section 3.1, first paragraph.
|
||||
host = sock_info.address[0]
|
||||
host = conn.address[0]
|
||||
if props.canonicalize_host_name:
|
||||
host = _canonicalize_hostname(host)
|
||||
service = props.service_name + "@" + host
|
||||
@ -413,7 +411,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
response = sock_info.command("$external", cmd)
|
||||
response = conn.command("$external", cmd)
|
||||
|
||||
# Limit how many times we loop to catch protocol / library issues
|
||||
for _ in range(10):
|
||||
@ -430,7 +428,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("payload", payload),
|
||||
]
|
||||
)
|
||||
response = sock_info.command("$external", cmd)
|
||||
response = conn.command("$external", cmd)
|
||||
|
||||
if result == kerberos.AUTH_GSS_COMPLETE:
|
||||
break
|
||||
@ -453,7 +451,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("payload", payload),
|
||||
]
|
||||
)
|
||||
sock_info.command("$external", cmd)
|
||||
conn.command("$external", cmd)
|
||||
|
||||
finally:
|
||||
kerberos.authGSSClientClean(ctx)
|
||||
@ -462,7 +460,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
raise OperationFailure(str(exc))
|
||||
|
||||
|
||||
def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using SASL PLAIN (RFC 4616)"""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
@ -476,52 +474,50 @@ def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) ->
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
sock_info.command(source, cmd)
|
||||
conn.command(source, cmd)
|
||||
|
||||
|
||||
def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using MONGODB-X509."""
|
||||
ctx = sock_info.auth_ctx
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
# MONGODB-X509 is done after the speculative auth step.
|
||||
return
|
||||
|
||||
cmd = _X509Context(credentials, sock_info.address).speculate_command()
|
||||
sock_info.command("$external", cmd)
|
||||
cmd = _X509Context(credentials, conn.address).speculate_command()
|
||||
conn.command("$external", cmd)
|
||||
|
||||
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using MONGODB-CR."""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
password = credentials.password
|
||||
# Get a nonce
|
||||
response = sock_info.command(source, {"getnonce": 1})
|
||||
response = conn.command(source, {"getnonce": 1})
|
||||
nonce = response["nonce"]
|
||||
key = _auth_key(nonce, username, password)
|
||||
|
||||
# Actually authenticate
|
||||
query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)])
|
||||
sock_info.command(source, query)
|
||||
conn.command(source, query)
|
||||
|
||||
|
||||
def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
if sock_info.max_wire_version >= 7:
|
||||
if sock_info.negotiated_mechs:
|
||||
mechs = sock_info.negotiated_mechs
|
||||
def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
|
||||
if conn.max_wire_version >= 7:
|
||||
if conn.negotiated_mechs:
|
||||
mechs = conn.negotiated_mechs
|
||||
else:
|
||||
source = credentials.source
|
||||
cmd = sock_info.hello_cmd()
|
||||
cmd = conn.hello_cmd()
|
||||
cmd["saslSupportedMechs"] = source + "." + credentials.username
|
||||
mechs = sock_info.command(source, cmd, publish_events=False).get(
|
||||
"saslSupportedMechs", []
|
||||
)
|
||||
mechs = conn.command(source, cmd, publish_events=False).get("saslSupportedMechs", [])
|
||||
if "SCRAM-SHA-256" in mechs:
|
||||
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256")
|
||||
return _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
|
||||
else:
|
||||
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1")
|
||||
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
|
||||
else:
|
||||
return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1")
|
||||
return _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
|
||||
|
||||
|
||||
_AUTH_MAP: Mapping[str, Callable] = {
|
||||
@ -606,12 +602,12 @@ _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
|
||||
|
||||
|
||||
def authenticate(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False
|
||||
credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
|
||||
) -> None:
|
||||
"""Authenticate sock_info."""
|
||||
"""Authenticate connection."""
|
||||
mechanism = credentials.mechanism
|
||||
auth_func = _AUTH_MAP[mechanism]
|
||||
if mechanism == "MONGODB-OIDC":
|
||||
_authenticate_oidc(credentials, sock_info, reauthenticate)
|
||||
_authenticate_oidc(credentials, conn, reauthenticate)
|
||||
else:
|
||||
auth_func(credentials, sock_info)
|
||||
auth_func(credentials, conn)
|
||||
|
||||
@ -49,7 +49,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
|
||||
if TYPE_CHECKING:
|
||||
from bson.typings import _ReadableBuffer
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
class _AwsSaslContext(AwsSaslContext): # type: ignore
|
||||
@ -67,7 +67,7 @@ class _AwsSaslContext(AwsSaslContext): # type: ignore
|
||||
return bson.decode(data)
|
||||
|
||||
|
||||
def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> None:
|
||||
def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using MONGODB-AWS."""
|
||||
if not _HAVE_MONGODB_AWS:
|
||||
raise ConfigurationError(
|
||||
@ -75,7 +75,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
|
||||
"install with: python -m pip install 'pymongo[aws]'"
|
||||
)
|
||||
|
||||
if sock_info.max_wire_version < 9:
|
||||
if conn.max_wire_version < 9:
|
||||
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
|
||||
|
||||
try:
|
||||
@ -90,7 +90,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
|
||||
client_first = SON(
|
||||
[("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)]
|
||||
)
|
||||
server_first = sock_info.command("$external", client_first)
|
||||
server_first = conn.command("$external", client_first)
|
||||
res = server_first
|
||||
# Limit how many times we loop to catch protocol / library issues
|
||||
for _ in range(10):
|
||||
@ -102,7 +102,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No
|
||||
("payload", client_payload),
|
||||
]
|
||||
)
|
||||
res = sock_info.command("$external", cmd)
|
||||
res = conn.command("$external", cmd)
|
||||
if res["done"]:
|
||||
# SASL complete.
|
||||
break
|
||||
|
||||
@ -29,7 +29,7 @@ from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.auth import MongoCredential
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -242,25 +242,23 @@ class _OIDCAuthenticator:
|
||||
self.idp_resp = None
|
||||
self.token_exp_utc = None
|
||||
|
||||
def run_command(
|
||||
self, sock_info: SocketInfo, cmd: Mapping[str, Any]
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
|
||||
try:
|
||||
return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
except OperationFailure as exc:
|
||||
self.clear()
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
if "jwt" in bson.decode(cmd["payload"]):
|
||||
if self.idp_info_gen_id > self.reauth_gen_id:
|
||||
raise
|
||||
return self.authenticate(sock_info, reauthenticate=True)
|
||||
return self.authenticate(conn, reauthenticate=True)
|
||||
raise
|
||||
|
||||
def authenticate(
|
||||
self, sock_info: SocketInfo, reauthenticate: bool = False
|
||||
self, conn: Connection, reauthenticate: bool = False
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
if reauthenticate:
|
||||
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
|
||||
prev_id = getattr(conn, "oidc_token_gen_id", None)
|
||||
# Check if we've already changed tokens.
|
||||
if prev_id == self.token_gen_id:
|
||||
self.reauth_gen_id = self.idp_info_gen_id
|
||||
@ -268,7 +266,7 @@ class _OIDCAuthenticator:
|
||||
if not self.properties.refresh_token_callback:
|
||||
self.clear()
|
||||
|
||||
ctx = sock_info.auth_ctx
|
||||
ctx = conn.auth_ctx
|
||||
cmd = None
|
||||
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
@ -276,10 +274,10 @@ class _OIDCAuthenticator:
|
||||
else:
|
||||
cmd = self.auth_start_cmd()
|
||||
assert cmd is not None
|
||||
resp = self.run_command(sock_info, cmd)
|
||||
resp = self.run_command(conn, cmd)
|
||||
|
||||
if resp["done"]:
|
||||
sock_info.oidc_token_gen_id = self.token_gen_id
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
return None
|
||||
|
||||
server_resp: Dict = bson.decode(resp["payload"])
|
||||
@ -289,7 +287,7 @@ class _OIDCAuthenticator:
|
||||
|
||||
conversation_id = resp["conversationId"]
|
||||
token = self.get_current_token()
|
||||
sock_info.oidc_token_gen_id = self.token_gen_id
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
bin_payload = Binary(bson.encode({"jwt": token}))
|
||||
cmd = SON(
|
||||
[
|
||||
@ -298,7 +296,7 @@ class _OIDCAuthenticator:
|
||||
("payload", bin_payload),
|
||||
]
|
||||
)
|
||||
resp = self.run_command(sock_info, cmd)
|
||||
resp = self.run_command(conn, cmd)
|
||||
if not resp["done"]:
|
||||
self.clear()
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
@ -306,8 +304,8 @@ class _OIDCAuthenticator:
|
||||
|
||||
|
||||
def _authenticate_oidc(
|
||||
credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool
|
||||
credentials: MongoCredential, conn: Connection, reauthenticate: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Authenticate using MONGODB-OIDC."""
|
||||
authenticator = _get_authenticator(credentials, sock_info.address)
|
||||
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)
|
||||
authenticator = _get_authenticator(credentials, conn.address)
|
||||
return authenticator.authenticate(conn, reauthenticate=reauthenticate)
|
||||
|
||||
@ -65,7 +65,7 @@ from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_DELETE_ALL: int = 0
|
||||
@ -311,7 +311,7 @@ class _Bulk:
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[ClientSession],
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
op_id: int,
|
||||
retryable: bool,
|
||||
full_result: MutableMapping[str, Any],
|
||||
@ -326,9 +326,9 @@ class _Bulk:
|
||||
self.next_run = None
|
||||
run = self.current_run
|
||||
|
||||
# sock_info.command validates the session, but we use
|
||||
# sock_info.write_command.
|
||||
sock_info.validate_session(client, session)
|
||||
# Connection.command validates the session, but we use
|
||||
# Connection.write_command
|
||||
conn.validate_session(client, session)
|
||||
last_run = False
|
||||
|
||||
while run:
|
||||
@ -341,7 +341,7 @@ class _Bulk:
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
sock_info,
|
||||
conn,
|
||||
op_id,
|
||||
listeners,
|
||||
session,
|
||||
@ -369,11 +369,11 @@ class _Bulk:
|
||||
if retryable and not self.started_retryable_write:
|
||||
session._start_retryable_write()
|
||||
self.started_retryable_write = True
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info)
|
||||
sock_info.send_cluster_time(cmd, session, client)
|
||||
sock_info.add_server_api(cmd)
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
|
||||
conn.send_cluster_time(cmd, session, client)
|
||||
conn.add_server_api(cmd)
|
||||
# CSOT: apply timeout before encoding the command.
|
||||
sock_info.apply_timeout(client, cmd)
|
||||
conn.apply_timeout(client, cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
|
||||
# Run as many ops as possible in one command.
|
||||
@ -430,13 +430,13 @@ class _Bulk:
|
||||
op_id = _randint()
|
||||
|
||||
def retryable_bulk(
|
||||
session: Optional[ClientSession], sock_info: SocketInfo, retryable: bool
|
||||
session: Optional[ClientSession], conn: Connection, retryable: bool
|
||||
) -> None:
|
||||
self._execute_command(
|
||||
generator,
|
||||
write_concern,
|
||||
session,
|
||||
sock_info,
|
||||
conn,
|
||||
op_id,
|
||||
retryable,
|
||||
full_result,
|
||||
@ -450,7 +450,7 @@ class _Bulk:
|
||||
_raise_bulk_write_error(full_result)
|
||||
return full_result
|
||||
|
||||
def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[Any]) -> None:
|
||||
def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None:
|
||||
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
|
||||
db_name = self.collection.database.name
|
||||
client = self.collection.database.client
|
||||
@ -466,7 +466,7 @@ class _Bulk:
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
sock_info,
|
||||
conn,
|
||||
op_id,
|
||||
listeners,
|
||||
None,
|
||||
@ -482,7 +482,7 @@ class _Bulk:
|
||||
("writeConcern", {"w": 0}),
|
||||
]
|
||||
)
|
||||
sock_info.add_server_api(cmd)
|
||||
conn.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible.
|
||||
to_send = bwc.execute_unack(cmd, ops, client)
|
||||
@ -491,7 +491,7 @@ class _Bulk:
|
||||
|
||||
def execute_command_no_results(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
@ -516,7 +516,7 @@ class _Bulk:
|
||||
generator,
|
||||
initial_write_concern,
|
||||
None,
|
||||
sock_info,
|
||||
conn,
|
||||
op_id,
|
||||
False,
|
||||
full_result,
|
||||
@ -527,7 +527,7 @@ class _Bulk:
|
||||
|
||||
def execute_no_results(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
@ -538,11 +538,11 @@ class _Bulk:
|
||||
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
|
||||
# Guard against unsupported unacknowledged writes.
|
||||
unack = write_concern and not write_concern.acknowledged
|
||||
if unack and self.uses_hint_delete and sock_info.max_wire_version < 9:
|
||||
if unack and self.uses_hint_delete and conn.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
if unack and self.uses_hint_update and sock_info.max_wire_version < 8:
|
||||
if unack and self.uses_hint_update and conn.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
@ -553,8 +553,8 @@ class _Bulk:
|
||||
)
|
||||
|
||||
if self.ordered:
|
||||
return self.execute_command_no_results(sock_info, generator, write_concern)
|
||||
return self.execute_op_msg_no_results(sock_info, generator)
|
||||
return self.execute_command_no_results(conn, generator, write_concern)
|
||||
return self.execute_op_msg_no_results(conn, generator)
|
||||
|
||||
def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any:
|
||||
"""Execute operations."""
|
||||
@ -573,8 +573,8 @@ class _Bulk:
|
||||
|
||||
client = self.collection.database.client
|
||||
if not write_concern.acknowledged:
|
||||
with client._socket_for_writes(session) as sock_info:
|
||||
self.execute_no_results(sock_info, generator, write_concern)
|
||||
with client._conn_for_writes(session) as connection:
|
||||
self.execute_no_results(connection, generator, write_concern)
|
||||
return None
|
||||
else:
|
||||
return self.execute_command(generator, write_concern, session)
|
||||
|
||||
@ -78,7 +78,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
|
||||
@ -213,7 +213,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> None:
|
||||
def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None:
|
||||
"""Callback that caches the postBatchResumeToken or
|
||||
startAtOperationTime from a changeStream aggregate command response
|
||||
containing an empty batch of change documents.
|
||||
@ -228,7 +228,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
self._start_at_operation_time is None
|
||||
and self._uses_resume_after is False
|
||||
and self._uses_start_after is False
|
||||
and sock_info.max_wire_version >= 7
|
||||
and conn.max_wire_version >= 7
|
||||
):
|
||||
self._start_at_operation_time = result.get("operationTime")
|
||||
# PYTHON-2181: informative error on missing operationTime.
|
||||
|
||||
@ -160,7 +160,7 @@ from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot
|
||||
from pymongo.cursor import _SocketManager
|
||||
from pymongo.cursor import _ConnectionManager
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
@ -178,7 +178,7 @@ from pymongo.write_concern import WriteConcern
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.server import Server
|
||||
|
||||
|
||||
@ -400,7 +400,7 @@ class _Transaction:
|
||||
self.state = _TxnState.NONE
|
||||
self.sharded = False
|
||||
self.pinned_address: Optional[Tuple[str, Optional[int]]] = None
|
||||
self.sock_mgr: Optional[_SocketManager] = None
|
||||
self.conn_mgr: Optional[_ConnectionManager] = None
|
||||
self.recovery_token = None
|
||||
self.attempt = 0
|
||||
self.client = client
|
||||
@ -412,23 +412,23 @@ class _Transaction:
|
||||
return self.state == _TxnState.STARTING
|
||||
|
||||
@property
|
||||
def pinned_conn(self) -> Optional[SocketInfo]:
|
||||
if self.active() and self.sock_mgr:
|
||||
return self.sock_mgr.sock
|
||||
def pinned_conn(self) -> Optional[Connection]:
|
||||
if self.active() and self.conn_mgr:
|
||||
return self.conn_mgr.conn
|
||||
return None
|
||||
|
||||
def pin(self, server: Server, sock_info: SocketInfo) -> None:
|
||||
def pin(self, server: Server, conn: Connection) -> None:
|
||||
self.sharded = True
|
||||
self.pinned_address = server.description.address
|
||||
if server.description.server_type == SERVER_TYPE.LoadBalancer:
|
||||
sock_info.pin_txn()
|
||||
self.sock_mgr = _SocketManager(sock_info, False)
|
||||
conn.pin_txn()
|
||||
self.conn_mgr = _ConnectionManager(conn, False)
|
||||
|
||||
def unpin(self) -> None:
|
||||
self.pinned_address = None
|
||||
if self.sock_mgr:
|
||||
self.sock_mgr.close()
|
||||
self.sock_mgr = None
|
||||
if self.conn_mgr:
|
||||
self.conn_mgr.close()
|
||||
self.conn_mgr = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.unpin()
|
||||
@ -438,11 +438,11 @@ class _Transaction:
|
||||
self.attempt = 0
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.sock_mgr:
|
||||
if self.conn_mgr:
|
||||
# Reuse the cursor closing machinery to return the socket to the
|
||||
# pool soon.
|
||||
self.client._close_cursor_soon(0, None, self.sock_mgr)
|
||||
self.sock_mgr = None
|
||||
self.client._close_cursor_soon(0, None, self.conn_mgr)
|
||||
self.conn_mgr = None
|
||||
|
||||
|
||||
def _reraise_with_unknown_commit(exc: Any) -> NoReturn:
|
||||
@ -839,12 +839,12 @@ class ClientSession:
|
||||
- `command_name`: Either "commitTransaction" or "abortTransaction".
|
||||
"""
|
||||
|
||||
def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]:
|
||||
return self._finish_transaction(sock_info, command_name)
|
||||
def func(session: ClientSession, conn: Connection, retryable: bool) -> Dict[str, Any]:
|
||||
return self._finish_transaction(conn, command_name)
|
||||
|
||||
return self._client._retry_internal(True, func, self, None)
|
||||
|
||||
def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]:
|
||||
def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]:
|
||||
self._transaction.attempt += 1
|
||||
opts = self._transaction.opts
|
||||
assert opts
|
||||
@ -868,7 +868,7 @@ class ClientSession:
|
||||
cmd["recoveryToken"] = self._transaction.recovery_token
|
||||
|
||||
return self._client.admin._command(
|
||||
sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True
|
||||
conn, cmd, session=self, write_concern=wc, parse_write_concern_error=True
|
||||
)
|
||||
|
||||
def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
|
||||
@ -954,13 +954,13 @@ class ClientSession:
|
||||
return None
|
||||
|
||||
@property
|
||||
def _pinned_connection(self) -> Optional[SocketInfo]:
|
||||
def _pinned_connection(self) -> Optional[Connection]:
|
||||
"""The connection this transaction was started on."""
|
||||
return self._transaction.pinned_conn
|
||||
|
||||
def _pin(self, server: Server, sock_info: SocketInfo) -> None:
|
||||
def _pin(self, server: Server, conn: Connection) -> None:
|
||||
"""Pin this session to the given Server or to the given connection."""
|
||||
self._transaction.pin(server, sock_info)
|
||||
self._transaction.pin(server, conn)
|
||||
|
||||
def _unpin(self) -> None:
|
||||
"""Unpin this session from any pinned Server."""
|
||||
@ -985,12 +985,12 @@ class ClientSession:
|
||||
command: MutableMapping[str, Any],
|
||||
is_retryable: bool,
|
||||
read_preference: ReadPreference,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
) -> None:
|
||||
self._check_ended()
|
||||
self._materialize()
|
||||
if self.options.snapshot:
|
||||
self._update_read_concern(command, sock_info)
|
||||
self._update_read_concern(command, conn)
|
||||
|
||||
self._server_session.last_use = time.monotonic()
|
||||
command["lsid"] = self._server_session.session_id
|
||||
@ -1016,7 +1016,7 @@ class ClientSession:
|
||||
rc = self._transaction.opts.read_concern.document
|
||||
if rc:
|
||||
command["readConcern"] = rc
|
||||
self._update_read_concern(command, sock_info)
|
||||
self._update_read_concern(command, conn)
|
||||
|
||||
command["txnNumber"] = self._server_session.transaction_id
|
||||
command["autocommit"] = False
|
||||
@ -1025,11 +1025,11 @@ class ClientSession:
|
||||
self._check_ended()
|
||||
self._server_session.inc_transaction_id()
|
||||
|
||||
def _update_read_concern(self, cmd: MutableMapping[str, Any], sock_info: SocketInfo) -> None:
|
||||
def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None:
|
||||
if self.options.causal_consistency and self.operation_time is not None:
|
||||
cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time
|
||||
if self.options.snapshot:
|
||||
if sock_info.max_wire_version < 13:
|
||||
if conn.max_wire_version < 13:
|
||||
raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later")
|
||||
rc = cmd.setdefault("readConcern", {})
|
||||
rc["level"] = "snapshot"
|
||||
|
||||
@ -126,7 +126,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.database import Database
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.server import Server
|
||||
|
||||
@ -262,17 +262,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
self.__create(name, kwargs, collation, session)
|
||||
|
||||
def _socket_for_reads(
|
||||
def _conn_for_reads(
|
||||
self, session: ClientSession
|
||||
) -> ContextManager[Tuple[SocketInfo, Union[PrimaryPreferred, Primary]]]:
|
||||
return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
|
||||
) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]:
|
||||
return self.__database.client._conn_for_reads(self._read_preference_for(session), session)
|
||||
|
||||
def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[SocketInfo]:
|
||||
return self.__database.client._socket_for_writes(session)
|
||||
def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]:
|
||||
return self.__database.client._conn_for_writes(session)
|
||||
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
command: Mapping[str, Any],
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
@ -288,7 +288,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Internal command helper.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info` - A SocketInfo instance.
|
||||
- `conn` - A Connection instance.
|
||||
- `command` - The command itself, as a :class:`~bson.son.SON` instance.
|
||||
- `read_preference` (optional) - The read preference to use.
|
||||
- `codec_options` (optional) - An instance of
|
||||
@ -313,7 +313,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
The result document.
|
||||
"""
|
||||
with self.__database.client._tmp_session(session) as s:
|
||||
return sock_info.command(
|
||||
return conn.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
read_preference or self._read_preference_for(session),
|
||||
@ -348,16 +348,16 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
if "size" in options:
|
||||
options["size"] = float(options["size"])
|
||||
cmd.update(options)
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
if qev2_required and sock_info.max_wire_version < 21:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
if qev2_required and conn.max_wire_version < 21:
|
||||
raise ConfigurationError(
|
||||
"Driver support of Queryable Encryption is incompatible with server. "
|
||||
"Upgrade server to use Queryable Encryption. "
|
||||
f"Got maxWireVersion {sock_info.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
|
||||
f"Got maxWireVersion {conn.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)"
|
||||
)
|
||||
|
||||
self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=self._write_concern_for(session),
|
||||
@ -597,12 +597,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
command["comment"] = comment
|
||||
|
||||
def _insert_command(
|
||||
session: ClientSession, sock_info: SocketInfo, retryable_write: bool
|
||||
session: ClientSession, conn: Connection, retryable_write: bool
|
||||
) -> None:
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
|
||||
result = sock_info.command(
|
||||
result = conn.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
write_concern=write_concern,
|
||||
@ -765,7 +765,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _update(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
criteria: Mapping[str, Any],
|
||||
document: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
@ -801,7 +801,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
update_doc["arrayFilters"] = array_filters
|
||||
if hint is not None:
|
||||
if not acknowledged and sock_info.max_wire_version < 8:
|
||||
if not acknowledged and conn.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
@ -821,7 +821,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
# The command result has to be published for APM unmodified
|
||||
# so we make a shallow copy here before adding updatedExisting.
|
||||
result = sock_info.command(
|
||||
result = conn.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
write_concern=write_concern,
|
||||
@ -865,10 +865,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Internal update / replace helper."""
|
||||
|
||||
def _update(
|
||||
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
return self._update(
|
||||
sock_info,
|
||||
conn,
|
||||
criteria,
|
||||
document,
|
||||
upsert=upsert,
|
||||
@ -1255,7 +1255,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _delete(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
criteria: Mapping[str, Any],
|
||||
multi: bool,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
@ -1280,7 +1280,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
delete_doc["collation"] = collation
|
||||
if hint is not None:
|
||||
if not acknowledged and sock_info.max_wire_version < 9:
|
||||
if not acknowledged and conn.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
@ -1297,7 +1297,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
command["comment"] = comment
|
||||
|
||||
# Delete command.
|
||||
result = sock_info.command(
|
||||
result = conn.command(
|
||||
self.__database.name,
|
||||
command,
|
||||
write_concern=write_concern,
|
||||
@ -1325,10 +1325,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""Internal delete helper."""
|
||||
|
||||
def _delete(
|
||||
session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
) -> Mapping[str, Any]:
|
||||
return self._delete(
|
||||
sock_info,
|
||||
conn,
|
||||
criteria,
|
||||
multi,
|
||||
write_concern=write_concern,
|
||||
@ -1738,7 +1738,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _count_cmd(
|
||||
self,
|
||||
session: ClientSession,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: Mapping[str, Any],
|
||||
collation: Optional[Collation],
|
||||
@ -1747,7 +1747,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
# XXX: "ns missing" checks can be removed when we drop support for
|
||||
# MongoDB 3.0, see SERVER-17051.
|
||||
res = self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=read_preference,
|
||||
allowable_errors=["ns missing"],
|
||||
@ -1762,7 +1762,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _aggregate_one_result(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn],
|
||||
@ -1770,7 +1770,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Internal helper to run an aggregate that returns a single result."""
|
||||
result = self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference,
|
||||
allowable_errors=[26], # Ignore NamespaceNotFound.
|
||||
@ -1821,12 +1821,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> int:
|
||||
cmd: SON[str, Any] = SON([("count", self.__name)])
|
||||
cmd.update(kwargs)
|
||||
return self._count_cmd(session, sock_info, read_preference, cmd, collation=None)
|
||||
return self._count_cmd(session, conn, read_preference, cmd, collation=None)
|
||||
|
||||
return self._retryable_non_cursor_read(_cmd, None)
|
||||
|
||||
@ -1910,10 +1910,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> int:
|
||||
result = self._aggregate_one_result(sock_info, read_preference, cmd, collation, session)
|
||||
result = self._aggregate_one_result(conn, read_preference, cmd, collation, session)
|
||||
if not result:
|
||||
return 0
|
||||
return result["n"]
|
||||
@ -1922,7 +1922,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _retryable_non_cursor_read(
|
||||
self,
|
||||
func: Callable[[ClientSession, Server, SocketInfo, Optional[_ServerMode]], T],
|
||||
func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T],
|
||||
session: Optional[ClientSession],
|
||||
) -> T:
|
||||
"""Non-cursor read helper to handle implicit session creation."""
|
||||
@ -1993,8 +1993,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
command (like maxTimeMS) can be passed as keyword arguments.
|
||||
"""
|
||||
names = []
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
supports_quorum = sock_info.max_wire_version >= 9
|
||||
with self._conn_for_writes(session) as conn:
|
||||
supports_quorum = conn.max_wire_version >= 9
|
||||
|
||||
def gen_indexes() -> Iterator[Mapping[str, Any]]:
|
||||
for index in indexes:
|
||||
@ -2015,7 +2015,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
@ -2236,9 +2236,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=["ns not found", 26],
|
||||
@ -2285,7 +2285,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> CommandCursor[_DocumentType]:
|
||||
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
|
||||
@ -2293,9 +2293,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd["comment"] = comment
|
||||
|
||||
try:
|
||||
cursor = self._command(
|
||||
sock_info, cmd, read_preference, codec_options, session=session
|
||||
)["cursor"]
|
||||
cursor = self._command(conn, cmd, read_preference, codec_options, session=session)[
|
||||
"cursor"
|
||||
]
|
||||
except OperationFailure as exc:
|
||||
# Ignore NamespaceNotFound errors to match the behavior
|
||||
# of reading from *.system.indexes.
|
||||
@ -2305,12 +2305,12 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
cursor,
|
||||
sock_info.address,
|
||||
conn.address,
|
||||
session=session,
|
||||
explicit_session=explicit_session,
|
||||
comment=cmd.get("comment"),
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(conn)
|
||||
return cmd_cursor
|
||||
|
||||
with self.__database.client._tmp_session(session, False) as s:
|
||||
@ -2479,9 +2479,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))])
|
||||
cmd.update(kwargs)
|
||||
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
resp = self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
@ -2514,9 +2514,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=["ns not found", 26],
|
||||
@ -2551,9 +2551,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd.update(kwargs)
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=["ns not found", 26],
|
||||
@ -2980,9 +2980,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd["comment"] = comment
|
||||
write_concern = self._write_concern_for_cmd(cmd, session)
|
||||
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
with self.__database.client._tmp_session(session) as s:
|
||||
return sock_info.command(
|
||||
return conn.command(
|
||||
"admin",
|
||||
cmd,
|
||||
write_concern=write_concern,
|
||||
@ -3049,11 +3049,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: ClientSession,
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> List:
|
||||
return self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=read_preference,
|
||||
read_concern=self.read_concern,
|
||||
@ -3112,7 +3112,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern = self._write_concern_for_cmd(cmd, session)
|
||||
|
||||
def _find_and_modify(
|
||||
session: ClientSession, sock_info: SocketInfo, retryable_write: bool
|
||||
session: ClientSession, conn: Connection, retryable_write: bool
|
||||
) -> Any:
|
||||
acknowledged = write_concern.acknowledged
|
||||
if array_filters is not None:
|
||||
@ -3122,17 +3122,17 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
cmd["arrayFilters"] = list(array_filters)
|
||||
if hint is not None:
|
||||
if sock_info.max_wire_version < 8:
|
||||
if conn.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on find and modify commands."
|
||||
)
|
||||
elif not acknowledged and sock_info.max_wire_version < 9:
|
||||
elif not acknowledged and conn.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands."
|
||||
)
|
||||
cmd["hint"] = hint
|
||||
out = self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
cmd,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=write_concern,
|
||||
|
||||
@ -29,7 +29,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager
|
||||
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
|
||||
from pymongo.response import PinnedResponse
|
||||
@ -38,7 +38,7 @@ from pymongo.typings import _Address, _DocumentType
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
class CommandCursor(Generic[_DocumentType]):
|
||||
@ -157,19 +157,19 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
"""
|
||||
return self.__postbatchresumetoken
|
||||
|
||||
def _maybe_pin_connection(self, sock_info: SocketInfo) -> None:
|
||||
def _maybe_pin_connection(self, conn: Connection) -> None:
|
||||
client = self.__collection.database.client
|
||||
if not client._should_pin_cursor(self.__session):
|
||||
return
|
||||
if not self.__sock_mgr:
|
||||
sock_info.pin_cursor()
|
||||
sock_mgr = _SocketManager(sock_info, False)
|
||||
conn.pin_cursor()
|
||||
conn_mgr = _ConnectionManager(conn, False)
|
||||
# Ensure the connection gets returned when the entire result is
|
||||
# returned in the first batch.
|
||||
if self.__id == 0:
|
||||
sock_mgr.close()
|
||||
conn_mgr.close()
|
||||
else:
|
||||
self.__sock_mgr = sock_mgr
|
||||
self.__sock_mgr = conn_mgr
|
||||
|
||||
def __send_message(self, operation: _GetMore) -> None:
|
||||
"""Send a getmore message and handle the response."""
|
||||
@ -197,7 +197,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self.__sock_mgr:
|
||||
self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come)
|
||||
self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
|
||||
if response.from_command:
|
||||
cursor = response.docs[0]["cursor"]
|
||||
documents = cursor["nextBatch"]
|
||||
|
||||
@ -64,7 +64,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
|
||||
|
||||
@ -139,11 +139,11 @@ class CursorType:
|
||||
"""
|
||||
|
||||
|
||||
class _SocketManager:
|
||||
"""Used with exhaust cursors to ensure the socket is returned."""
|
||||
class _ConnectionManager:
|
||||
"""Used with exhaust cursors to ensure the connection is returned."""
|
||||
|
||||
def __init__(self, sock: SocketInfo, more_to_come: bool):
|
||||
self.sock: Optional[SocketInfo] = sock
|
||||
def __init__(self, conn: Connection, more_to_come: bool):
|
||||
self.conn: Optional[Connection] = conn
|
||||
self.more_to_come = more_to_come
|
||||
self.lock = _create_lock()
|
||||
|
||||
@ -151,10 +151,10 @@ class _SocketManager:
|
||||
self.more_to_come = more_to_come
|
||||
|
||||
def close(self) -> None:
|
||||
"""Return this instance's socket to the connection pool."""
|
||||
if self.sock:
|
||||
self.sock.unpin()
|
||||
self.sock = None
|
||||
"""Return this instance's connection to the connection pool."""
|
||||
if self.conn:
|
||||
self.conn.unpin()
|
||||
self.conn = None
|
||||
|
||||
|
||||
_Sort = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]]
|
||||
@ -1085,7 +1085,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
self.__address = response.address
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self.__sock_mgr:
|
||||
self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come)
|
||||
self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
|
||||
|
||||
cmd_name = operation.name
|
||||
docs = response.docs
|
||||
|
||||
@ -48,7 +48,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.server import Server
|
||||
|
||||
|
||||
@ -689,7 +689,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
@overload
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -706,7 +706,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
@overload
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -722,7 +722,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _command(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
@ -742,7 +742,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
command.update(kwargs)
|
||||
with self.__client._tmp_session(session) as s:
|
||||
return sock_info.command(
|
||||
return conn.command(
|
||||
self.__name,
|
||||
command,
|
||||
read_preference,
|
||||
@ -889,12 +889,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
if read_preference is None:
|
||||
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
|
||||
with self.__client._socket_for_reads(read_preference, session) as (
|
||||
sock_info,
|
||||
with self.__client._conn_for_reads(read_preference, session) as (
|
||||
connection,
|
||||
read_preference,
|
||||
):
|
||||
return self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
command,
|
||||
value,
|
||||
check,
|
||||
@ -973,12 +973,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
read_preference = (
|
||||
tmp_session and tmp_session._txn_read_preference()
|
||||
) or ReadPreference.PRIMARY
|
||||
with self.__client._socket_for_reads(read_preference, tmp_session) as (
|
||||
sock_info,
|
||||
with self.__client._conn_for_reads(read_preference, tmp_session) as (
|
||||
conn,
|
||||
read_preference,
|
||||
):
|
||||
response = self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
command,
|
||||
value,
|
||||
True,
|
||||
@ -993,13 +993,13 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
response["cursor"],
|
||||
sock_info.address,
|
||||
conn.address,
|
||||
max_await_time_ms=max_await_time_ms,
|
||||
session=tmp_session,
|
||||
explicit_session=session is not None,
|
||||
comment=comment,
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(conn)
|
||||
return cmd_cursor
|
||||
else:
|
||||
raise InvalidOperation("Command does not return a cursor.")
|
||||
@ -1015,11 +1015,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> Dict[str, Any]:
|
||||
return self._command(
|
||||
sock_info,
|
||||
conn,
|
||||
command,
|
||||
read_preference=read_preference,
|
||||
session=session,
|
||||
@ -1029,7 +1029,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _list_collections(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
session: Optional[ClientSession],
|
||||
read_preference: _ServerMode,
|
||||
**kwargs: Any,
|
||||
@ -1039,18 +1039,18 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
cmd = SON([("listCollections", 1), ("cursor", {})])
|
||||
cmd.update(kwargs)
|
||||
with self.__client._tmp_session(session, close=False) as tmp_session:
|
||||
cursor = self._command(
|
||||
sock_info, cmd, read_preference=read_preference, session=tmp_session
|
||||
)["cursor"]
|
||||
cursor = self._command(conn, cmd, read_preference=read_preference, session=tmp_session)[
|
||||
"cursor"
|
||||
]
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
cursor,
|
||||
sock_info.address,
|
||||
conn.address,
|
||||
session=tmp_session,
|
||||
explicit_session=session is not None,
|
||||
comment=cmd.get("comment"),
|
||||
)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
cmd_cursor._maybe_pin_connection(conn)
|
||||
return cmd_cursor
|
||||
|
||||
def list_collections(
|
||||
@ -1090,12 +1090,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
def _cmd(
|
||||
session: Optional[ClientSession],
|
||||
server: Server,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
read_preference: _ServerMode,
|
||||
) -> CommandCursor[_DocumentType]:
|
||||
return self._list_collections(
|
||||
sock_info, session, read_preference=read_preference, **kwargs
|
||||
)
|
||||
return self._list_collections(conn, session, read_preference=read_preference, **kwargs)
|
||||
|
||||
return self.__client._retryable_read(_cmd, read_pref, session)
|
||||
|
||||
@ -1154,9 +1152,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
if comment is not None:
|
||||
command["comment"] = comment
|
||||
|
||||
with self.__client._socket_for_writes(session) as sock_info:
|
||||
with self.__client._conn_for_writes(session) as connection:
|
||||
return self._command(
|
||||
sock_info,
|
||||
connection,
|
||||
command,
|
||||
allowable_errors=["ns not found", 26],
|
||||
write_concern=self._write_concern_for(session),
|
||||
|
||||
@ -307,7 +307,7 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
def _handle_reauth(func: F) -> F:
|
||||
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@ -315,19 +315,19 @@ def _handle_reauth(func: F) -> F:
|
||||
if no_reauth:
|
||||
raise
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
# Look for an argument that either is a SocketInfo
|
||||
# or has a socket_info attribute, so we can trigger
|
||||
# Look for an argument that either is a Connection
|
||||
# or has a connection attribute, so we can trigger
|
||||
# a reauth.
|
||||
sock_info = None
|
||||
conn = None
|
||||
for arg in args:
|
||||
if isinstance(arg, SocketInfo):
|
||||
sock_info = arg
|
||||
if isinstance(arg, Connection):
|
||||
conn = arg
|
||||
break
|
||||
if hasattr(arg, "sock_info"):
|
||||
sock_info = arg.sock_info
|
||||
if hasattr(arg, "connection"):
|
||||
conn = arg.conn
|
||||
break
|
||||
if sock_info:
|
||||
sock_info.authenticate(reauthenticate=True)
|
||||
if conn:
|
||||
conn.authenticate(reauthenticate=True)
|
||||
else:
|
||||
raise
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@ -227,14 +227,14 @@ def _gen_find_command(
|
||||
return cmd
|
||||
|
||||
|
||||
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, sock_info):
|
||||
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, conn):
|
||||
"""Generate a getMore command document."""
|
||||
cmd = SON([("getMore", cursor_id), ("collection", coll)])
|
||||
if batch_size:
|
||||
cmd["batchSize"] = batch_size
|
||||
if max_await_time_ms is not None:
|
||||
cmd["maxTimeMS"] = max_await_time_ms
|
||||
if comment is not None and sock_info.max_wire_version >= 9:
|
||||
if comment is not None and conn.max_wire_version >= 9:
|
||||
cmd["comment"] = comment
|
||||
return cmd
|
||||
|
||||
@ -264,7 +264,7 @@ class _Query:
|
||||
)
|
||||
|
||||
# For compatibility with the _GetMore class.
|
||||
sock_mgr = None
|
||||
conn_mgr = None
|
||||
cursor_id = None
|
||||
|
||||
def __init__(
|
||||
@ -311,24 +311,23 @@ class _Query:
|
||||
def namespace(self):
|
||||
return f"{self.db}.{self.coll}"
|
||||
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, conn):
|
||||
use_find_cmd = False
|
||||
if not self.exhaust:
|
||||
use_find_cmd = True
|
||||
elif sock_info.max_wire_version >= 8:
|
||||
elif conn.max_wire_version >= 8:
|
||||
# OP_MSG supports exhaust on MongoDB 4.2+
|
||||
use_find_cmd = True
|
||||
elif not self.read_concern.ok_for_legacy:
|
||||
raise ConfigurationError(
|
||||
"read concern level of %s is not valid "
|
||||
"with a max wire version of %d."
|
||||
% (self.read_concern.level, sock_info.max_wire_version)
|
||||
"with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version)
|
||||
)
|
||||
|
||||
sock_info.validate_session(self.client, self.session)
|
||||
conn.validate_session(self.client, self.session)
|
||||
return use_find_cmd
|
||||
|
||||
def as_command(self, sock_info, apply_timeout=False):
|
||||
def as_command(self, conn, apply_timeout=False):
|
||||
"""Return a find command document for this query."""
|
||||
# We use the command twice: on the wire and for command monitoring.
|
||||
# Generate it once, for speed and to avoid repeating side-effects.
|
||||
@ -353,24 +352,24 @@ class _Query:
|
||||
self.name = "explain"
|
||||
cmd = SON([("explain", cmd)])
|
||||
session = self.session
|
||||
sock_info.add_server_api(cmd)
|
||||
conn.add_server_api(cmd)
|
||||
if session:
|
||||
session._apply_to(cmd, False, self.read_preference, sock_info)
|
||||
session._apply_to(cmd, False, self.read_preference, conn)
|
||||
# Explain does not support readConcern.
|
||||
if not explain and not session.in_transaction:
|
||||
session._update_read_concern(cmd, sock_info)
|
||||
sock_info.send_cluster_time(cmd, session, self.client)
|
||||
session._update_read_concern(cmd, conn)
|
||||
conn.send_cluster_time(cmd, session, self.client)
|
||||
# Support auto encryption
|
||||
client = self.client
|
||||
if client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
|
||||
# Support CSOT
|
||||
if apply_timeout:
|
||||
sock_info.apply_timeout(client, cmd)
|
||||
conn.apply_timeout(client, cmd)
|
||||
self._as_command = cmd, self.db
|
||||
return self._as_command
|
||||
|
||||
def get_message(self, read_preference, sock_info, use_cmd=False):
|
||||
def get_message(self, read_preference, conn, use_cmd=False):
|
||||
"""Get a query message, possibly setting the secondaryOk bit."""
|
||||
# Use the read_preference decided by _socket_from_server.
|
||||
self.read_preference = read_preference
|
||||
@ -384,14 +383,14 @@ class _Query:
|
||||
spec = self.spec
|
||||
|
||||
if use_cmd:
|
||||
spec = self.as_command(sock_info, apply_timeout=True)[0]
|
||||
spec = self.as_command(conn, apply_timeout=True)[0]
|
||||
request_id, msg, size, _ = _op_msg(
|
||||
0,
|
||||
spec,
|
||||
self.db,
|
||||
read_preference,
|
||||
self.codec_options,
|
||||
ctx=sock_info.compression_context,
|
||||
ctx=conn.compression_context,
|
||||
)
|
||||
return request_id, msg, size
|
||||
|
||||
@ -405,7 +404,7 @@ class _Query:
|
||||
else:
|
||||
ntoreturn = self.limit
|
||||
|
||||
if sock_info.is_mongos:
|
||||
if conn.is_mongos:
|
||||
spec = _maybe_add_read_preference(spec, read_preference)
|
||||
|
||||
return _query(
|
||||
@ -416,7 +415,7 @@ class _Query:
|
||||
spec,
|
||||
None if use_cmd else self.fields,
|
||||
self.codec_options,
|
||||
ctx=sock_info.compression_context,
|
||||
ctx=conn.compression_context,
|
||||
)
|
||||
|
||||
|
||||
@ -433,7 +432,7 @@ class _GetMore:
|
||||
"read_preference",
|
||||
"session",
|
||||
"client",
|
||||
"sock_mgr",
|
||||
"conn_mgr",
|
||||
"_as_command",
|
||||
"exhaust",
|
||||
"comment",
|
||||
@ -452,7 +451,7 @@ class _GetMore:
|
||||
session,
|
||||
client,
|
||||
max_await_time_ms,
|
||||
sock_mgr,
|
||||
conn_mgr,
|
||||
exhaust,
|
||||
comment,
|
||||
):
|
||||
@ -465,7 +464,7 @@ class _GetMore:
|
||||
self.session = session
|
||||
self.client = client
|
||||
self.max_await_time_ms = max_await_time_ms
|
||||
self.sock_mgr = sock_mgr
|
||||
self.conn_mgr = conn_mgr
|
||||
self._as_command = None
|
||||
self.exhaust = exhaust
|
||||
self.comment = comment
|
||||
@ -476,18 +475,18 @@ class _GetMore:
|
||||
def namespace(self):
|
||||
return f"{self.db}.{self.coll}"
|
||||
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, conn):
|
||||
use_cmd = False
|
||||
if not self.exhaust:
|
||||
use_cmd = True
|
||||
elif sock_info.max_wire_version >= 8:
|
||||
elif conn.max_wire_version >= 8:
|
||||
# OP_MSG supports exhaust on MongoDB 4.2+
|
||||
use_cmd = True
|
||||
|
||||
sock_info.validate_session(self.client, self.session)
|
||||
conn.validate_session(self.client, self.session)
|
||||
return use_cmd
|
||||
|
||||
def as_command(self, sock_info, apply_timeout=False):
|
||||
def as_command(self, conn, apply_timeout=False):
|
||||
"""Return a getMore command document for this query."""
|
||||
# See _Query.as_command for an explanation of this caching.
|
||||
if self._as_command is not None:
|
||||
@ -499,35 +498,35 @@ class _GetMore:
|
||||
self.ntoreturn,
|
||||
self.max_await_time_ms,
|
||||
self.comment,
|
||||
sock_info,
|
||||
conn,
|
||||
)
|
||||
if self.session:
|
||||
self.session._apply_to(cmd, False, self.read_preference, sock_info)
|
||||
sock_info.add_server_api(cmd)
|
||||
sock_info.send_cluster_time(cmd, self.session, self.client)
|
||||
self.session._apply_to(cmd, False, self.read_preference, conn)
|
||||
conn.add_server_api(cmd)
|
||||
conn.send_cluster_time(cmd, self.session, self.client)
|
||||
# Support auto encryption
|
||||
client = self.client
|
||||
if client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
|
||||
# Support CSOT
|
||||
if apply_timeout:
|
||||
sock_info.apply_timeout(client, cmd=None)
|
||||
conn.apply_timeout(client, cmd=None)
|
||||
self._as_command = cmd, self.db
|
||||
return self._as_command
|
||||
|
||||
def get_message(self, dummy0, sock_info, use_cmd=False):
|
||||
def get_message(self, dummy0, conn, use_cmd=False):
|
||||
"""Get a getmore message."""
|
||||
ns = self.namespace()
|
||||
ctx = sock_info.compression_context
|
||||
ctx = conn.compression_context
|
||||
|
||||
if use_cmd:
|
||||
spec = self.as_command(sock_info, apply_timeout=True)[0]
|
||||
if self.sock_mgr:
|
||||
spec = self.as_command(conn, apply_timeout=True)[0]
|
||||
if self.conn_mgr:
|
||||
flags = _OpMsg.EXHAUST_ALLOWED
|
||||
else:
|
||||
flags = 0
|
||||
request_id, msg, size, _ = _op_msg(
|
||||
flags, spec, self.db, None, self.codec_options, ctx=sock_info.compression_context
|
||||
flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
|
||||
)
|
||||
return request_id, msg, size
|
||||
|
||||
@ -535,10 +534,10 @@ class _GetMore:
|
||||
|
||||
|
||||
class _RawBatchQuery(_Query):
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, conn):
|
||||
# Compatibility checks.
|
||||
super().use_command(sock_info)
|
||||
if sock_info.max_wire_version >= 8:
|
||||
super().use_command(conn)
|
||||
if conn.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
elif not self.exhaust:
|
||||
@ -547,10 +546,10 @@ class _RawBatchQuery(_Query):
|
||||
|
||||
|
||||
class _RawBatchGetMore(_GetMore):
|
||||
def use_command(self, sock_info):
|
||||
def use_command(self, conn):
|
||||
# Compatibility checks.
|
||||
super().use_command(sock_info)
|
||||
if sock_info.max_wire_version >= 8:
|
||||
super().use_command(conn)
|
||||
if conn.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
elif not self.exhaust:
|
||||
@ -794,11 +793,11 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
|
||||
|
||||
|
||||
class _BulkWriteContext:
|
||||
"""A wrapper around SocketInfo for use with write splitting functions."""
|
||||
"""A wrapper around Connection for use with write splitting functions."""
|
||||
|
||||
__slots__ = (
|
||||
"db_name",
|
||||
"sock_info",
|
||||
"conn",
|
||||
"op_id",
|
||||
"name",
|
||||
"field",
|
||||
@ -812,10 +811,10 @@ class _BulkWriteContext:
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec
|
||||
self, database_name, cmd_name, conn, operation_id, listeners, session, op_type, codec
|
||||
):
|
||||
self.db_name = database_name
|
||||
self.sock_info = sock_info
|
||||
self.conn = conn
|
||||
self.op_id = operation_id
|
||||
self.listeners = listeners
|
||||
self.publish = listeners.enabled_for_commands
|
||||
@ -823,7 +822,7 @@ class _BulkWriteContext:
|
||||
self.field = _FIELD_MAP[self.name]
|
||||
self.start_time = datetime.datetime.now() if self.publish else None
|
||||
self.session = session
|
||||
self.compress = True if sock_info.compression_context else False
|
||||
self.compress = True if conn.compression_context else False
|
||||
self.op_type = op_type
|
||||
self.codec = codec
|
||||
|
||||
@ -855,20 +854,20 @@ class _BulkWriteContext:
|
||||
@property
|
||||
def max_bson_size(self):
|
||||
"""A proxy for SockInfo.max_bson_size."""
|
||||
return self.sock_info.max_bson_size
|
||||
return self.conn.max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self):
|
||||
"""A proxy for SockInfo.max_message_size."""
|
||||
if self.compress:
|
||||
# Subtract 16 bytes for the message header.
|
||||
return self.sock_info.max_message_size - 16
|
||||
return self.sock_info.max_message_size
|
||||
return self.conn.max_message_size - 16
|
||||
return self.conn.max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self):
|
||||
"""A proxy for SockInfo.max_write_batch_size."""
|
||||
return self.sock_info.max_write_batch_size
|
||||
return self.conn.max_write_batch_size
|
||||
|
||||
@property
|
||||
def max_split_size(self):
|
||||
@ -876,14 +875,14 @@ class _BulkWriteContext:
|
||||
return self.max_bson_size
|
||||
|
||||
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
|
||||
"""A proxy for SocketInfo.unack_write that handles event publishing."""
|
||||
"""A proxy for Connection.unack_write that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
cmd = self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
result = self.sock_info.unack_write(msg, max_doc_size)
|
||||
result = self.conn.unack_write(msg, max_doc_size)
|
||||
if self.publish:
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
if result is not None:
|
||||
@ -910,14 +909,14 @@ class _BulkWriteContext:
|
||||
|
||||
@_handle_reauth
|
||||
def write_command(self, cmd, request_id, msg, docs):
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
"""A proxy for Connection.write_command that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
reply = self.sock_info.write_command(request_id, msg, self.codec)
|
||||
reply = self.conn.write_command(request_id, msg, self.codec)
|
||||
if self.publish:
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
self._succeed(request_id, reply, duration)
|
||||
@ -941,9 +940,9 @@ class _BulkWriteContext:
|
||||
cmd,
|
||||
self.db_name,
|
||||
request_id,
|
||||
self.sock_info.address,
|
||||
self.conn.address,
|
||||
self.op_id,
|
||||
self.sock_info.service_id,
|
||||
self.conn.service_id,
|
||||
)
|
||||
return cmd
|
||||
|
||||
@ -954,9 +953,9 @@ class _BulkWriteContext:
|
||||
reply,
|
||||
self.name,
|
||||
request_id,
|
||||
self.sock_info.address,
|
||||
self.conn.address,
|
||||
self.op_id,
|
||||
self.sock_info.service_id,
|
||||
self.conn.service_id,
|
||||
)
|
||||
|
||||
def _fail(self, request_id, failure, duration):
|
||||
@ -966,9 +965,9 @@ class _BulkWriteContext:
|
||||
failure,
|
||||
self.name,
|
||||
request_id,
|
||||
self.sock_info.address,
|
||||
self.conn.address,
|
||||
self.op_id,
|
||||
self.sock_info.service_id,
|
||||
self.conn.service_id,
|
||||
)
|
||||
|
||||
|
||||
@ -997,14 +996,14 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
|
||||
def execute(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
result = self.sock_info.command(
|
||||
result = self.conn.command(
|
||||
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
|
||||
)
|
||||
return result, to_send
|
||||
|
||||
def execute_unack(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
self.sock_info.command(
|
||||
self.conn.command(
|
||||
self.db_name,
|
||||
batched_cmd,
|
||||
write_concern=WriteConcern(w=0),
|
||||
@ -1124,7 +1123,7 @@ def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx):
|
||||
"""
|
||||
data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
|
||||
|
||||
request_id, msg = _compress(2013, data, ctx.sock_info.compression_context)
|
||||
request_id, msg = _compress(2013, data, ctx.conn.compression_context)
|
||||
return request_id, msg, to_send
|
||||
|
||||
|
||||
@ -1162,7 +1161,7 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
|
||||
ack = bool(command["writeConcern"].get("w", 1))
|
||||
else:
|
||||
ack = True
|
||||
if ctx.sock_info.compression_context:
|
||||
if ctx.conn.compression_context:
|
||||
return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx)
|
||||
return _batched_op_msg(operation, command, docs, ack, opts, ctx)
|
||||
|
||||
|
||||
@ -1160,18 +1160,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _end_sessions(self, session_ids):
|
||||
"""Send endSessions command(s) with the given session ids."""
|
||||
try:
|
||||
# Use SocketInfo.command directly to avoid implicitly creating
|
||||
# Use Connection.command directly to avoid implicitly creating
|
||||
# another session.
|
||||
with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as (
|
||||
sock_info,
|
||||
with self._conn_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as (
|
||||
conn,
|
||||
read_pref,
|
||||
):
|
||||
if not sock_info.supports_sessions:
|
||||
if not conn.supports_sessions:
|
||||
return
|
||||
|
||||
for i in range(0, len(session_ids), common._MAX_END_SESSIONS):
|
||||
spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])])
|
||||
sock_info.command("admin", spec, read_preference=read_pref, client=self)
|
||||
conn.command("admin", spec, read_preference=read_pref, client=self)
|
||||
except PyMongoError:
|
||||
# Drivers MUST ignore any errors returned by the endSessions
|
||||
# command.
|
||||
@ -1216,7 +1216,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return self._topology
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _get_socket(self, server, session):
|
||||
def _checkout(self, server, session):
|
||||
in_txn = session and session.in_transaction
|
||||
with _MongoClientErrorHandler(self, server, session) as err_handler:
|
||||
# Reuse the pinned connection, if it exists.
|
||||
@ -1224,23 +1224,23 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
err_handler.contribute_socket(session._pinned_connection)
|
||||
yield session._pinned_connection
|
||||
return
|
||||
with server.get_socket(handler=err_handler) as sock_info:
|
||||
with server.checkout(handler=err_handler) as conn:
|
||||
# Pin this session to the selected server or connection.
|
||||
if in_txn and server.description.server_type in (
|
||||
SERVER_TYPE.Mongos,
|
||||
SERVER_TYPE.LoadBalancer,
|
||||
):
|
||||
session._pin(server, sock_info)
|
||||
err_handler.contribute_socket(sock_info)
|
||||
session._pin(server, conn)
|
||||
err_handler.contribute_socket(conn)
|
||||
if (
|
||||
self._encrypter
|
||||
and not self._encrypter._bypass_auto_encryption
|
||||
and sock_info.max_wire_version < 8
|
||||
and conn.max_wire_version < 8
|
||||
):
|
||||
raise ConfigurationError(
|
||||
"Auto-encryption requires a minimum MongoDB version of 4.2"
|
||||
)
|
||||
yield sock_info
|
||||
yield conn
|
||||
|
||||
def _select_server(self, server_selector, session, address=None):
|
||||
"""Select a server to run an operation on this client.
|
||||
@ -1273,15 +1273,15 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
session._unpin()
|
||||
raise
|
||||
|
||||
def _socket_for_writes(self, session):
|
||||
def _conn_for_writes(self, session):
|
||||
server = self._select_server(writable_server_selector, session)
|
||||
return self._get_socket(server, session)
|
||||
return self._checkout(server, session)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _socket_from_server(self, read_preference, server, session):
|
||||
def _conn_from_server(self, read_preference, server, session):
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
# Get a socket for a server matching the read preference, and yield
|
||||
# sock_info with the effective read preference. The Server Selection
|
||||
# Get a connection for a server matching the read preference, and yield
|
||||
# conn with the effective read preference. The Server Selection
|
||||
# Spec says not to send any $readPreference to standalones and to
|
||||
# always send primaryPreferred when directly connected to a repl set
|
||||
# member.
|
||||
@ -1289,22 +1289,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
topology = self._get_topology()
|
||||
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
|
||||
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
with self._checkout(server, session) as conn:
|
||||
if single:
|
||||
if sock_info.is_repl and not (session and session.in_transaction):
|
||||
if conn.is_repl and not (session and session.in_transaction):
|
||||
# Use primary preferred to ensure any repl set member
|
||||
# can handle the request.
|
||||
read_preference = ReadPreference.PRIMARY_PREFERRED
|
||||
elif sock_info.is_standalone:
|
||||
elif conn.is_standalone:
|
||||
# Don't send read preference to standalones.
|
||||
read_preference = ReadPreference.PRIMARY
|
||||
yield sock_info, read_preference
|
||||
yield conn, read_preference
|
||||
|
||||
def _socket_for_reads(self, read_preference, session):
|
||||
def _conn_for_reads(self, read_preference, session):
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
_ = self._get_topology()
|
||||
server = self._select_server(read_preference, session)
|
||||
return self._socket_from_server(read_preference, server, session)
|
||||
return self._conn_from_server(read_preference, server, session)
|
||||
|
||||
def _should_pin_cursor(self, session):
|
||||
return self.__options.load_balanced and not (session and session.in_transaction)
|
||||
@ -1319,22 +1319,26 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
- `address` (optional): Optional address when sending a message
|
||||
to a specific server, used for getMore.
|
||||
"""
|
||||
if operation.sock_mgr:
|
||||
if operation.conn_mgr:
|
||||
server = self._select_server(
|
||||
operation.read_preference, operation.session, address=address
|
||||
)
|
||||
|
||||
with operation.sock_mgr.lock:
|
||||
with operation.conn_mgr.lock:
|
||||
with _MongoClientErrorHandler(self, server, operation.session) as err_handler:
|
||||
err_handler.contribute_socket(operation.sock_mgr.sock)
|
||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||
return server.run_operation(
|
||||
operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res
|
||||
operation.conn_mgr.conn,
|
||||
operation,
|
||||
True,
|
||||
self._event_listeners,
|
||||
unpack_res,
|
||||
)
|
||||
|
||||
def _cmd(session, server, sock_info, read_preference):
|
||||
def _cmd(session, server, conn, read_preference):
|
||||
operation.reset() # Reset op in case of retry.
|
||||
return server.run_operation(
|
||||
sock_info, operation, read_preference, self._event_listeners, unpack_res
|
||||
conn, operation, read_preference, self._event_listeners, unpack_res
|
||||
)
|
||||
|
||||
return self._retryable_read(
|
||||
@ -1388,8 +1392,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
supports_session = (
|
||||
session is not None and server.description.retryable_writes_supported
|
||||
)
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
max_wire_version = sock_info.max_wire_version
|
||||
with self._checkout(server, session) as conn:
|
||||
max_wire_version = conn.max_wire_version
|
||||
if retryable and not supports_session:
|
||||
if is_retrying():
|
||||
# A retry is not possible because this server does
|
||||
@ -1397,7 +1401,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
retryable = False
|
||||
return func(session, sock_info, retryable)
|
||||
return func(session, conn, retryable)
|
||||
except ServerSelectionTimeoutError:
|
||||
if is_retrying():
|
||||
# The application may think the write was never attempted
|
||||
@ -1455,13 +1459,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
raise last_error
|
||||
try:
|
||||
server = self._select_server(read_pref, session, address=address)
|
||||
with self._socket_from_server(read_pref, server, session) as (sock_info, read_pref):
|
||||
with self._conn_from_server(read_pref, server, session) as (
|
||||
conn,
|
||||
read_pref,
|
||||
):
|
||||
if retrying and not retryable:
|
||||
# A retry is not possible because this server does
|
||||
# not support retryable reads, raise the last error.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
return func(session, server, sock_info, read_pref)
|
||||
return func(session, server, conn, read_pref)
|
||||
except ServerSelectionTimeoutError:
|
||||
if retrying:
|
||||
# The application may think the write was never attempted
|
||||
@ -1566,7 +1573,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return database.Database(self, name)
|
||||
|
||||
def _cleanup_cursor(
|
||||
self, locks_allowed, cursor_id, address, sock_mgr, session, explicit_session
|
||||
self, locks_allowed, cursor_id, address, conn_mgr, session, explicit_session
|
||||
):
|
||||
"""Cleanup a cursor from cursor.close() or __del__.
|
||||
|
||||
@ -1578,33 +1585,33 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
- `locks_allowed`: True if we are allowed to acquire locks.
|
||||
- `cursor_id`: The cursor id which may be 0.
|
||||
- `address`: The _CursorAddress.
|
||||
- `sock_mgr`: The _SocketManager for the pinned connection or None.
|
||||
- `conn_mgr`: The _ConnectionManager for the pinned connection or None.
|
||||
- `session`: The cursor's session.
|
||||
- `explicit_session`: True if the session was passed explicitly.
|
||||
"""
|
||||
if locks_allowed:
|
||||
if cursor_id:
|
||||
if sock_mgr and sock_mgr.more_to_come:
|
||||
if conn_mgr and conn_mgr.more_to_come:
|
||||
# If this is an exhaust cursor and we haven't completely
|
||||
# exhausted the result set we *must* close the socket
|
||||
# to stop the server from sending more data.
|
||||
sock_mgr.sock.close_socket(ConnectionClosedReason.ERROR)
|
||||
conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
else:
|
||||
self._close_cursor_now(cursor_id, address, session=session, sock_mgr=sock_mgr)
|
||||
if sock_mgr:
|
||||
sock_mgr.close()
|
||||
self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
|
||||
if conn_mgr:
|
||||
conn_mgr.close()
|
||||
else:
|
||||
# The cursor will be closed later in a different session.
|
||||
if cursor_id or sock_mgr:
|
||||
self._close_cursor_soon(cursor_id, address, sock_mgr)
|
||||
if cursor_id or conn_mgr:
|
||||
self._close_cursor_soon(cursor_id, address, conn_mgr)
|
||||
if session and not explicit_session:
|
||||
session._end_session(lock=locks_allowed)
|
||||
|
||||
def _close_cursor_soon(self, cursor_id, address, sock_mgr=None):
|
||||
def _close_cursor_soon(self, cursor_id, address, conn_mgr=None):
|
||||
"""Request that a cursor and/or connection be cleaned up soon."""
|
||||
self.__kill_cursors_queue.append((address, cursor_id, sock_mgr))
|
||||
self.__kill_cursors_queue.append((address, cursor_id, conn_mgr))
|
||||
|
||||
def _close_cursor_now(self, cursor_id, address=None, session=None, sock_mgr=None):
|
||||
def _close_cursor_now(self, cursor_id, address=None, session=None, conn_mgr=None):
|
||||
"""Send a kill cursors message with the given id.
|
||||
|
||||
The cursor is closed synchronously on the current thread.
|
||||
@ -1613,10 +1620,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
raise TypeError("cursor_id must be an instance of int")
|
||||
|
||||
try:
|
||||
if sock_mgr:
|
||||
with sock_mgr.lock:
|
||||
if conn_mgr:
|
||||
with conn_mgr.lock:
|
||||
# Cursor is pinned to LB outside of a transaction.
|
||||
self._kill_cursor_impl([cursor_id], address, session, sock_mgr.sock)
|
||||
self._kill_cursor_impl([cursor_id], address, session, conn_mgr.conn)
|
||||
else:
|
||||
self._kill_cursors([cursor_id], address, self._get_topology(), session)
|
||||
except PyMongoError:
|
||||
@ -1633,14 +1640,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Application called close_cursor() with no address.
|
||||
server = topology.select_server(writable_server_selector)
|
||||
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
self._kill_cursor_impl(cursor_ids, address, session, sock_info)
|
||||
with self._checkout(server, session) as conn:
|
||||
self._kill_cursor_impl(cursor_ids, address, session, conn)
|
||||
|
||||
def _kill_cursor_impl(self, cursor_ids, address, session, sock_info):
|
||||
def _kill_cursor_impl(self, cursor_ids, address, session, conn):
|
||||
namespace = address.namespace
|
||||
db, coll = namespace.split(".", 1)
|
||||
spec = SON([("killCursors", coll), ("cursors", cursor_ids)])
|
||||
sock_info.command(db, spec, session=session, client=self)
|
||||
conn.command(db, spec, session=session, client=self)
|
||||
|
||||
def _process_kill_cursors(self):
|
||||
"""Process any pending kill cursors requests."""
|
||||
@ -1650,18 +1657,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Other threads or the GC may append to the queue concurrently.
|
||||
while True:
|
||||
try:
|
||||
address, cursor_id, sock_mgr = self.__kill_cursors_queue.pop()
|
||||
address, cursor_id, conn_mgr = self.__kill_cursors_queue.pop()
|
||||
except IndexError:
|
||||
break
|
||||
|
||||
if sock_mgr:
|
||||
pinned_cursors.append((address, cursor_id, sock_mgr))
|
||||
if conn_mgr:
|
||||
pinned_cursors.append((address, cursor_id, conn_mgr))
|
||||
else:
|
||||
address_to_cursor_ids[address].append(cursor_id)
|
||||
|
||||
for address, cursor_id, sock_mgr in pinned_cursors:
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
self._cleanup_cursor(True, cursor_id, address, sock_mgr, None, False)
|
||||
self._cleanup_cursor(True, cursor_id, address, conn_mgr, None, False)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -1925,9 +1932,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_database must be an instance of str or a Database")
|
||||
|
||||
with self._socket_for_writes(session) as sock_info:
|
||||
with self._conn_for_writes(session) as conn:
|
||||
self[name]._command(
|
||||
sock_info,
|
||||
conn,
|
||||
{"dropDatabase": 1, "comment": comment},
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=self._write_concern_for(session),
|
||||
@ -2149,11 +2156,11 @@ class _MongoClientErrorHandler:
|
||||
self.service_id = None
|
||||
self.handled = False
|
||||
|
||||
def contribute_socket(self, sock_info, completed_handshake=True):
|
||||
def contribute_socket(self, conn, completed_handshake=True):
|
||||
"""Provide socket information to the error handler."""
|
||||
self.max_wire_version = sock_info.max_wire_version
|
||||
self.sock_generation = sock_info.generation
|
||||
self.service_id = sock_info.service_id
|
||||
self.max_wire_version = conn.max_wire_version
|
||||
self.sock_generation = conn.generation
|
||||
self.service_id = conn.service_id
|
||||
self.completed_handshake = completed_handshake
|
||||
|
||||
def handle(self, exc_type, exc_val):
|
||||
|
||||
@ -245,9 +245,9 @@ class Monitor(MonitorBase):
|
||||
|
||||
if self._cancel_context and self._cancel_context.cancelled:
|
||||
self._reset_connection()
|
||||
with self._pool.get_socket() as sock_info:
|
||||
self._cancel_context = sock_info.cancel_context
|
||||
response, round_trip_time = self._check_with_socket(sock_info)
|
||||
with self._pool.checkout() as conn:
|
||||
self._cancel_context = conn.cancel_context
|
||||
response, round_trip_time = self._check_with_socket(conn)
|
||||
if not response.awaitable:
|
||||
self._rtt_monitor.add_sample(round_trip_time)
|
||||
|
||||
@ -393,11 +393,11 @@ class _RttMonitor(MonitorBase):
|
||||
|
||||
def _ping(self):
|
||||
"""Run a "hello" command and return the RTT."""
|
||||
with self._pool.get_socket() as sock_info:
|
||||
with self._pool.checkout() as conn:
|
||||
if self._executor._stopped:
|
||||
raise Exception("_RttMonitor closed")
|
||||
start = time.monotonic()
|
||||
sock_info.hello()
|
||||
conn.hello()
|
||||
return time.monotonic() - start
|
||||
|
||||
|
||||
|
||||
@ -135,15 +135,15 @@ Connection monitoring and pooling events are also available. For example::
|
||||
logging.info("[pool {0.address}] pool closed".format(event))
|
||||
|
||||
def connection_created(self, event):
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
logging.info("[pool {0.address}][connection #{0.connection_id}] "
|
||||
"connection created".format(event))
|
||||
|
||||
def connection_ready(self, event):
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
logging.info("[pool {0.address}][connection #{0.connection_id}] "
|
||||
"connection setup succeeded".format(event))
|
||||
|
||||
def connection_closed(self, event):
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
logging.info("[pool {0.address}][connection #{0.connection_id}] "
|
||||
"connection closed, reason: "
|
||||
"{0.reason}".format(event))
|
||||
|
||||
@ -156,11 +156,11 @@ Connection monitoring and pooling events are also available. For example::
|
||||
"failed, reason: {0.reason}".format(event))
|
||||
|
||||
def connection_checked_out(self, event):
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
logging.info("[pool {0.address}][connection #{0.connection_id}] "
|
||||
"connection checked out of pool".format(event))
|
||||
|
||||
def connection_checked_in(self, event):
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
logging.info("[pool {0.address}][connection #{0.connection_id}] "
|
||||
"connection checked into pool".format(event))
|
||||
|
||||
|
||||
@ -268,7 +268,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def pool_created(self, event: "PoolCreatedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`PoolCreatedEvent`.
|
||||
|
||||
Emitted when a Connection Pool is created.
|
||||
Emitted when a connection Pool is created.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`PoolCreatedEvent`.
|
||||
@ -278,7 +278,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def pool_ready(self, event: "PoolReadyEvent") -> None:
|
||||
"""Abstract method to handle a :class:`PoolReadyEvent`.
|
||||
|
||||
Emitted when a Connection Pool is marked ready.
|
||||
Emitted when a connection Pool is marked ready.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`PoolReadyEvent`.
|
||||
@ -290,7 +290,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def pool_cleared(self, event: "PoolClearedEvent") -> None:
|
||||
"""Abstract method to handle a `PoolClearedEvent`.
|
||||
|
||||
Emitted when a Connection Pool is cleared.
|
||||
Emitted when a connection Pool is cleared.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`PoolClearedEvent`.
|
||||
@ -300,7 +300,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def pool_closed(self, event: "PoolClosedEvent") -> None:
|
||||
"""Abstract method to handle a `PoolClosedEvent`.
|
||||
|
||||
Emitted when a Connection Pool is closed.
|
||||
Emitted when a connection Pool is closed.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`PoolClosedEvent`.
|
||||
@ -310,7 +310,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def connection_created(self, event: "ConnectionCreatedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCreatedEvent`.
|
||||
|
||||
Emitted when a Connection Pool creates a Connection object.
|
||||
Emitted when a connection Pool creates a Connection object.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`ConnectionCreatedEvent`.
|
||||
@ -320,7 +320,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def connection_ready(self, event: "ConnectionReadyEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionReadyEvent`.
|
||||
|
||||
Emitted when a Connection has finished its setup, and is now ready to
|
||||
Emitted when a connection has finished its setup, and is now ready to
|
||||
use.
|
||||
|
||||
:Parameters:
|
||||
@ -331,7 +331,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def connection_closed(self, event: "ConnectionClosedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionClosedEvent`.
|
||||
|
||||
Emitted when a Connection Pool closes a Connection.
|
||||
Emitted when a connection Pool closes a connection.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`ConnectionClosedEvent`.
|
||||
@ -361,7 +361,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCheckedOutEvent`.
|
||||
|
||||
Emitted when the driver successfully checks out a Connection.
|
||||
Emitted when the driver successfully checks out a connection.
|
||||
|
||||
:Parameters:
|
||||
- `event`: An instance of :class:`ConnectionCheckedOutEvent`.
|
||||
@ -371,7 +371,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCheckedInEvent`.
|
||||
|
||||
Emitted when the driver checks in a Connection back to the Connection
|
||||
Emitted when the driver checks in a connection back to the connection
|
||||
Pool.
|
||||
|
||||
:Parameters:
|
||||
@ -948,7 +948,7 @@ class _ConnectionIdEvent(_ConnectionEvent):
|
||||
|
||||
@property
|
||||
def connection_id(self) -> int:
|
||||
"""The ID of the Connection."""
|
||||
"""The ID of the connection."""
|
||||
return self.__connection_id
|
||||
|
||||
def __repr__(self):
|
||||
@ -1066,7 +1066,7 @@ class ConnectionCheckOutFailedEvent(_ConnectionEvent):
|
||||
|
||||
|
||||
class ConnectionCheckedOutEvent(_ConnectionIdEvent):
|
||||
"""Published when the driver successfully checks out a Connection.
|
||||
"""Published when the driver successfully checks out a connection.
|
||||
|
||||
:Parameters:
|
||||
- `address`: The address (host, port) pair of the server this
|
||||
|
||||
@ -52,7 +52,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address
|
||||
@ -62,7 +62,7 @@ _UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
|
||||
|
||||
def command(
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
is_mongos: bool,
|
||||
@ -88,7 +88,7 @@ def command(
|
||||
"""Execute a command over the socket, or raise socket.error.
|
||||
|
||||
:Parameters:
|
||||
- `sock`: a raw socket instance
|
||||
- `conn`: a Connection instance
|
||||
- `dbname`: name of the database on which to run the command
|
||||
- `spec`: a command document as an ordered dict type, eg SON.
|
||||
- `is_mongos`: are we connected to a mongos?
|
||||
@ -98,7 +98,7 @@ def command(
|
||||
- `client`: optional MongoClient instance for updating $clusterTime.
|
||||
- `check`: raise OperationFailure if there are errors
|
||||
- `allowable_errors`: errors to ignore if `check` is True
|
||||
- `address`: the (host, port) of `sock`
|
||||
- `address`: the (host, port) of `conn`
|
||||
- `listeners`: An instance of :class:`~pymongo.monitoring.EventListeners`
|
||||
- `max_bson_size`: The maximum encoded bson size for this server
|
||||
- `read_concern`: The read concern for this command.
|
||||
@ -125,7 +125,7 @@ def command(
|
||||
if read_concern.level:
|
||||
spec["readConcern"] = read_concern.document
|
||||
if session:
|
||||
session._update_read_concern(spec, sock_info)
|
||||
session._update_read_concern(spec, conn)
|
||||
if collation is not None:
|
||||
spec["collation"] = collation
|
||||
|
||||
@ -142,7 +142,7 @@ def command(
|
||||
|
||||
# Support CSOT
|
||||
if client:
|
||||
sock_info.apply_timeout(client, spec)
|
||||
conn.apply_timeout(client, spec)
|
||||
_csot.apply_write_concern(spec, write_concern)
|
||||
|
||||
if use_op_msg:
|
||||
@ -167,19 +167,19 @@ def command(
|
||||
encoding_duration = datetime.datetime.now() - start
|
||||
assert listeners is not None
|
||||
listeners.publish_command_start(
|
||||
orig, dbname, request_id, address, service_id=sock_info.service_id
|
||||
orig, dbname, request_id, address, service_id=conn.service_id
|
||||
)
|
||||
start = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
sock_info.sock.sendall(msg)
|
||||
conn.conn.sendall(msg)
|
||||
if use_op_msg and unacknowledged:
|
||||
# Unacknowledged, fake a successful command response.
|
||||
reply = None
|
||||
response_doc = {"ok": 1}
|
||||
else:
|
||||
reply = receive_message(sock_info, request_id)
|
||||
sock_info.more_to_come = reply.more_to_come
|
||||
reply = receive_message(conn, request_id)
|
||||
conn.more_to_come = reply.more_to_come
|
||||
unpacked_docs = reply.unpack_response(
|
||||
codec_options=codec_options, user_fields=user_fields
|
||||
)
|
||||
@ -190,7 +190,7 @@ def command(
|
||||
if check:
|
||||
helpers._check_command_response(
|
||||
response_doc,
|
||||
sock_info.max_wire_version,
|
||||
conn.max_wire_version,
|
||||
allowable_errors,
|
||||
parse_write_concern_error=parse_write_concern_error,
|
||||
)
|
||||
@ -203,7 +203,7 @@ def command(
|
||||
failure = message._convert_exception(exc)
|
||||
assert listeners is not None
|
||||
listeners.publish_command_failure(
|
||||
duration, failure, name, request_id, address, service_id=sock_info.service_id
|
||||
duration, failure, name, request_id, address, service_id=conn.service_id
|
||||
)
|
||||
raise
|
||||
if publish:
|
||||
@ -215,7 +215,7 @@ def command(
|
||||
name,
|
||||
request_id,
|
||||
address,
|
||||
service_id=sock_info.service_id,
|
||||
service_id=conn.service_id,
|
||||
speculative_hello=speculative_hello,
|
||||
)
|
||||
|
||||
@ -230,21 +230,19 @@ _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
|
||||
|
||||
def receive_message(
|
||||
sock_info: SocketInfo, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
|
||||
conn: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
timeout = sock_info.sock.gettimeout()
|
||||
timeout = conn.conn.gettimeout()
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(
|
||||
_receive_data_on_socket(sock_info, 16, deadline)
|
||||
)
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
@ -260,11 +258,11 @@ def receive_message(
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
|
||||
_receive_data_on_socket(sock_info, 9, deadline)
|
||||
_receive_data_on_socket(conn, 9, deadline)
|
||||
)
|
||||
data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id)
|
||||
data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = _receive_data_on_socket(sock_info, length - 16, deadline)
|
||||
data = _receive_data_on_socket(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
@ -276,12 +274,12 @@ def receive_message(
|
||||
_POLL_TIMEOUT = 0.5
|
||||
|
||||
|
||||
def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
context = sock_info.cancel_context
|
||||
context = conn.cancel_context
|
||||
# Only Monitor connections can be cancelled.
|
||||
if context:
|
||||
sock = sock_info.sock
|
||||
sock = conn.conn
|
||||
timed_out = False
|
||||
while True:
|
||||
# SSLSocket can have buffered data which won't be caught by select.
|
||||
@ -300,7 +298,7 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
|
||||
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
|
||||
else:
|
||||
timeout = _POLL_TIMEOUT
|
||||
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
|
||||
if context.cancelled:
|
||||
raise _OperationCancelled("hello cancelled")
|
||||
if readable:
|
||||
@ -313,21 +311,19 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None:
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
def _receive_data_on_socket(
|
||||
sock_info: SocketInfo, length: int, deadline: Optional[float]
|
||||
) -> memoryview:
|
||||
def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
try:
|
||||
wait_for_read(sock_info, deadline)
|
||||
wait_for_read(conn, deadline)
|
||||
# CSOT: Update timeout. When the timeout has expired perform one
|
||||
# final non-blocking recv. This helps avoid spurious timeouts when
|
||||
# the response is actually already buffered on the client.
|
||||
if _csot.get_timeout() and deadline is not None:
|
||||
sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = sock_info.sock.recv_into(mv[bytes_read:])
|
||||
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
|
||||
chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out")
|
||||
except OSError as exc: # noqa: B014
|
||||
|
||||
232
pymongo/pool.py
232
pymongo/pool.py
@ -614,19 +614,19 @@ class _CancellationContext:
|
||||
return self._cancelled
|
||||
|
||||
|
||||
class SocketInfo:
|
||||
"""Store a socket with some metadata.
|
||||
class Connection:
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
:Parameters:
|
||||
- `sock`: a raw socket object
|
||||
- `conn`: a raw connection object
|
||||
- `pool`: a Pool instance
|
||||
- `address`: the server's (host, port)
|
||||
- `id`: the id of this socket in it's pool
|
||||
"""
|
||||
|
||||
def __init__(self, sock, pool, address, id):
|
||||
def __init__(self, conn, pool, address, id):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.sock = sock
|
||||
self.conn = conn
|
||||
self.address = address
|
||||
self.id = id
|
||||
self.authed = set()
|
||||
@ -673,12 +673,12 @@ class SocketInfo:
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.connect_rtt = 0.0
|
||||
|
||||
def set_socket_timeout(self, timeout):
|
||||
"""Cache last timeout to avoid duplicate calls to sock.settimeout."""
|
||||
def set_conn_timeout(self, timeout):
|
||||
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
|
||||
if timeout == self.last_timeout:
|
||||
return
|
||||
self.last_timeout = timeout
|
||||
self.sock.settimeout(timeout)
|
||||
self.conn.settimeout(timeout)
|
||||
|
||||
def apply_timeout(self, client, cmd):
|
||||
# CSOT: use remaining timeout when set.
|
||||
@ -686,7 +686,7 @@ class SocketInfo:
|
||||
if timeout is None:
|
||||
# Reset the socket timeout unless we're performing a streaming monitor check.
|
||||
if not self.more_to_come:
|
||||
self.set_socket_timeout(self.opts.socket_timeout)
|
||||
self.set_conn_timeout(self.opts.socket_timeout)
|
||||
return None
|
||||
# RTT validation.
|
||||
rtt = _csot.get_rtt()
|
||||
@ -701,7 +701,7 @@ class SocketInfo:
|
||||
)
|
||||
if cmd is not None:
|
||||
cmd["maxTimeMS"] = int(max_time_ms * 1000)
|
||||
self.set_socket_timeout(timeout)
|
||||
self.set_conn_timeout(timeout)
|
||||
return timeout
|
||||
|
||||
def pin_txn(self):
|
||||
@ -715,9 +715,9 @@ class SocketInfo:
|
||||
def unpin(self):
|
||||
pool = self.pool_ref()
|
||||
if pool:
|
||||
pool.return_socket(self)
|
||||
pool.checkin(self)
|
||||
else:
|
||||
self.close_socket(ConnectionClosedReason.STALE)
|
||||
self.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
def hello_cmd(self):
|
||||
# Handshake spec requires us to use OP_MSG+hello command for the
|
||||
@ -748,7 +748,7 @@ class SocketInfo:
|
||||
awaitable = True
|
||||
# If connect_timeout is None there is no timeout.
|
||||
if self.opts.connect_timeout:
|
||||
self.set_socket_timeout(self.opts.connect_timeout + heartbeat_frequency)
|
||||
self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency)
|
||||
|
||||
if not performing_handshake and cluster_time is not None:
|
||||
cmd["$clusterTime"] = cluster_time
|
||||
@ -919,7 +919,7 @@ class SocketInfo:
|
||||
)
|
||||
|
||||
try:
|
||||
self.sock.sendall(message)
|
||||
self.conn.sendall(message)
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -999,15 +999,15 @@ class SocketInfo:
|
||||
if session._client is not client:
|
||||
raise InvalidOperation("Can only use session with the MongoClient that started it")
|
||||
|
||||
def close_socket(self, reason):
|
||||
def close_conn(self, reason):
|
||||
"""Close this connection with a reason."""
|
||||
if self.closed:
|
||||
return
|
||||
self._close_socket()
|
||||
self._close_conn()
|
||||
if reason and self.enabled_for_cmap:
|
||||
self.listeners.publish_connection_closed(self.address, self.id, reason)
|
||||
|
||||
def _close_socket(self):
|
||||
def _close_conn(self):
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
@ -1017,13 +1017,13 @@ class SocketInfo:
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
self.sock.close()
|
||||
self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def socket_closed(self):
|
||||
def conn_closed(self):
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
return self.socket_checker.socket_closed(self.sock)
|
||||
return self.socket_checker.socket_closed(self.conn)
|
||||
|
||||
def send_cluster_time(self, command, session, client):
|
||||
"""Add $clusterTime."""
|
||||
@ -1060,12 +1060,12 @@ class SocketInfo:
|
||||
# KeyboardInterrupt from the start, rather than as an initial
|
||||
# socket.error, so we catch that, close the socket, and reraise it.
|
||||
#
|
||||
# The connection closed event will be emitted later in return_socket.
|
||||
# The connection closed event will be emitted later in checkin.
|
||||
if self.ready:
|
||||
reason = None
|
||||
else:
|
||||
reason = ConnectionClosedReason.ERROR
|
||||
self.close_socket(reason)
|
||||
self.close_conn(reason)
|
||||
# SSLError from PyOpenSSL inherits directly from Exception.
|
||||
if isinstance(error, (IOError, OSError, SSLError)):
|
||||
_raise_connection_failure(self.address, error)
|
||||
@ -1073,17 +1073,17 @@ class SocketInfo:
|
||||
raise
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.sock == other.sock
|
||||
return self.conn == other.conn
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.sock)
|
||||
return hash(self.conn)
|
||||
|
||||
def __repr__(self):
|
||||
return "SocketInfo({}){} at {}".format(
|
||||
repr(self.sock),
|
||||
return "Connection({}){} at {}".format(
|
||||
repr(self.conn),
|
||||
self.closed and " CLOSED" or "",
|
||||
id(self),
|
||||
)
|
||||
@ -1256,7 +1256,7 @@ class Pool:
|
||||
:Parameters:
|
||||
- `address`: a (hostname, port) tuple
|
||||
- `options`: a PoolOptions instance
|
||||
- `handshake`: whether to call hello for each new SocketInfo
|
||||
- `handshake`: whether to call hello for each new Connection
|
||||
"""
|
||||
if options.pause_enabled:
|
||||
self.state = PoolState.PAUSED
|
||||
@ -1268,7 +1268,7 @@ class Pool:
|
||||
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
|
||||
# and returned to pool from the left side. Stale sockets removed
|
||||
# from the right side.
|
||||
self.sockets: collections.deque = collections.deque()
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.lock = _create_lock()
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
@ -1344,17 +1344,17 @@ class Pool:
|
||||
self.active_sockets = 0
|
||||
self.operation_count = 0
|
||||
if service_id is None:
|
||||
sockets, self.sockets = self.sockets, collections.deque()
|
||||
sockets, self.conns = self.conns, collections.deque()
|
||||
else:
|
||||
discard: collections.deque = collections.deque()
|
||||
keep: collections.deque = collections.deque()
|
||||
for sock_info in self.sockets:
|
||||
if sock_info.service_id == service_id:
|
||||
discard.append(sock_info)
|
||||
for conn in self.conns:
|
||||
if conn.service_id == service_id:
|
||||
discard.append(conn)
|
||||
else:
|
||||
keep.append(sock_info)
|
||||
keep.append(conn)
|
||||
sockets = discard
|
||||
self.sockets = keep
|
||||
self.conns = keep
|
||||
|
||||
if close:
|
||||
self.state = PoolState.CLOSED
|
||||
@ -1367,15 +1367,15 @@ class Pool:
|
||||
# PoolClosedEvent but that reset() SHOULD close sockets *after*
|
||||
# publishing the PoolClearedEvent.
|
||||
if close:
|
||||
for sock_info in sockets:
|
||||
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_pool_closed(self.address)
|
||||
else:
|
||||
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
|
||||
listeners.publish_pool_cleared(self.address, service_id=service_id)
|
||||
for sock_info in sockets:
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
def update_is_writable(self, is_writable):
|
||||
"""Updates the is_writable attribute on all sockets currently in the
|
||||
@ -1383,7 +1383,7 @@ class Pool:
|
||||
"""
|
||||
self.is_writable = is_writable
|
||||
with self.lock:
|
||||
for _socket in self.sockets:
|
||||
for _socket in self.conns:
|
||||
_socket.update_is_writable(self.is_writable)
|
||||
|
||||
def reset(self, service_id=None):
|
||||
@ -1412,16 +1412,16 @@ class Pool:
|
||||
if self.opts.max_idle_time_seconds is not None:
|
||||
with self.lock:
|
||||
while (
|
||||
self.sockets
|
||||
and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
self.conns
|
||||
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
sock_info = self.sockets.pop()
|
||||
sock_info.close_socket(ConnectionClosedReason.IDLE)
|
||||
conn = self.conns.pop()
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
with self.size_cond:
|
||||
# There are enough sockets in the pool.
|
||||
if len(self.sockets) + self.active_sockets >= self.opts.min_pool_size:
|
||||
if len(self.conns) + self.active_sockets >= self.opts.min_pool_size:
|
||||
return
|
||||
if self.requests >= self.opts.min_pool_size:
|
||||
return
|
||||
@ -1435,14 +1435,14 @@ class Pool:
|
||||
return
|
||||
self._pending += 1
|
||||
incremented = True
|
||||
sock_info = self.connect()
|
||||
conn = self.connect()
|
||||
with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
if self.gen.get_overall() != reference_generation:
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return
|
||||
self.sockets.appendleft(sock_info)
|
||||
self.conns.appendleft(conn)
|
||||
finally:
|
||||
if incremented:
|
||||
# Notify after adding the socket to the pool.
|
||||
@ -1455,12 +1455,12 @@ class Pool:
|
||||
self.size_cond.notify()
|
||||
|
||||
def connect(self, handler=None):
|
||||
"""Connect to Mongo and return a new SocketInfo.
|
||||
"""Connect to Mongo and return a new Connection.
|
||||
|
||||
Can raise ConnectionFailure.
|
||||
|
||||
Note that the pool does not keep a reference to the socket -- you
|
||||
must call return_socket() when you're done with it.
|
||||
must call checkin() when you're done with it.
|
||||
"""
|
||||
with self.lock:
|
||||
conn_id = self.next_connection_id
|
||||
@ -1483,33 +1483,33 @@ class Pool:
|
||||
|
||||
raise
|
||||
|
||||
sock_info = SocketInfo(sock, self, self.address, conn_id)
|
||||
conn = Connection(sock, self, self.address, conn_id)
|
||||
try:
|
||||
if self.handshake:
|
||||
sock_info.hello()
|
||||
self.is_writable = sock_info.is_writable
|
||||
conn.hello()
|
||||
self.is_writable = conn.is_writable
|
||||
if handler:
|
||||
handler.contribute_socket(sock_info, completed_handshake=False)
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
sock_info.authenticate()
|
||||
conn.authenticate()
|
||||
except BaseException:
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
raise
|
||||
|
||||
return sock_info
|
||||
return conn
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, handler=None):
|
||||
"""Get a socket from the pool. Use with a "with" statement.
|
||||
def checkout(self, handler=None):
|
||||
"""Get a connection from the pool. Use with a "with" statement.
|
||||
|
||||
Returns a :class:`SocketInfo` object wrapping a connected
|
||||
Returns a :class:`Connection` object wrapping a connected
|
||||
:class:`socket.socket`.
|
||||
|
||||
This method should always be used in a with-statement::
|
||||
|
||||
with pool.get_socket() as socket_info:
|
||||
socket_info.send_message(msg)
|
||||
data = socket_info.receive_message(op_code, request_id)
|
||||
with pool.get_conn() as connection:
|
||||
connection.send_message(msg)
|
||||
data = connection.receive_message(op_code, request_id)
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
@ -1520,36 +1520,36 @@ class Pool:
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_check_out_started(self.address)
|
||||
|
||||
sock_info = self._get_socket(handler=handler)
|
||||
conn = self._get_conn(handler=handler)
|
||||
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_out(self.address, sock_info.id)
|
||||
listeners.publish_connection_checked_out(self.address, conn.id)
|
||||
try:
|
||||
yield sock_info
|
||||
yield conn
|
||||
except BaseException:
|
||||
# Exception in caller. Ensure the connection gets returned.
|
||||
# Note that when pinned is True, the session owns the
|
||||
# connection and it is responsible for checking the connection
|
||||
# back into the pool.
|
||||
pinned = sock_info.pinned_txn or sock_info.pinned_cursor
|
||||
pinned = conn.pinned_txn or conn.pinned_cursor
|
||||
if handler:
|
||||
# Perform SDAM error handling rules while the connection is
|
||||
# still checked out.
|
||||
exc_type, exc_val, _ = sys.exc_info()
|
||||
handler.handle(exc_type, exc_val)
|
||||
if not pinned and sock_info.active:
|
||||
self.return_socket(sock_info)
|
||||
if not pinned and conn.active:
|
||||
self.checkin(conn)
|
||||
raise
|
||||
if sock_info.pinned_txn:
|
||||
if conn.pinned_txn:
|
||||
with self.lock:
|
||||
self.__pinned_sockets.add(sock_info)
|
||||
self.__pinned_sockets.add(conn)
|
||||
self.ntxns += 1
|
||||
elif sock_info.pinned_cursor:
|
||||
elif conn.pinned_cursor:
|
||||
with self.lock:
|
||||
self.__pinned_sockets.add(sock_info)
|
||||
self.__pinned_sockets.add(conn)
|
||||
self.ncursors += 1
|
||||
elif sock_info.active:
|
||||
self.return_socket(sock_info)
|
||||
elif conn.active:
|
||||
self.checkin(conn)
|
||||
|
||||
def _raise_if_not_ready(self, emit_event):
|
||||
if self.state != PoolState.READY:
|
||||
@ -1559,8 +1559,8 @@ class Pool:
|
||||
)
|
||||
_raise_connection_failure(self.address, AutoReconnect("connection pool paused"))
|
||||
|
||||
def _get_socket(self, handler=None):
|
||||
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
|
||||
def _get_conn(self, handler=None):
|
||||
"""Get or create a Connection. Can raise ConnectionFailure."""
|
||||
# We use the pid here to avoid issues with fork / multiprocessing.
|
||||
# See test.test_client:TestClient.test_fork for an example of
|
||||
# what could go wrong otherwise
|
||||
@ -1600,7 +1600,7 @@ class Pool:
|
||||
self.requests += 1
|
||||
|
||||
# We've now acquired the semaphore and must release it on error.
|
||||
sock_info = None
|
||||
conn = None
|
||||
incremented = False
|
||||
emitted_event = False
|
||||
try:
|
||||
@ -1608,40 +1608,40 @@ class Pool:
|
||||
self.active_sockets += 1
|
||||
incremented = True
|
||||
|
||||
while sock_info is None:
|
||||
while conn is None:
|
||||
# CMAP: we MUST wait for either maxConnecting OR for a socket
|
||||
# to be checked back into the pool.
|
||||
with self._max_connecting_cond:
|
||||
self._raise_if_not_ready(emit_event=False)
|
||||
while not (self.sockets or self._pending < self._max_connecting):
|
||||
while not (self.conns or self._pending < self._max_connecting):
|
||||
if not _cond_wait(self._max_connecting_cond, deadline):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.sockets or self._pending < self._max_connecting:
|
||||
if self.conns or self._pending < self._max_connecting:
|
||||
self._max_connecting_cond.notify()
|
||||
emitted_event = True
|
||||
self._raise_wait_queue_timeout()
|
||||
self._raise_if_not_ready(emit_event=False)
|
||||
|
||||
try:
|
||||
sock_info = self.sockets.popleft()
|
||||
conn = self.conns.popleft()
|
||||
except IndexError:
|
||||
self._pending += 1
|
||||
if sock_info: # We got a socket from the pool
|
||||
if self._perished(sock_info):
|
||||
sock_info = None
|
||||
if conn: # We got a socket from the pool
|
||||
if self._perished(conn):
|
||||
conn = None
|
||||
continue
|
||||
else: # We need to create a new connection
|
||||
try:
|
||||
sock_info = self.connect(handler=handler)
|
||||
conn = self.connect(handler=handler)
|
||||
finally:
|
||||
with self._max_connecting_cond:
|
||||
self._pending -= 1
|
||||
self._max_connecting_cond.notify()
|
||||
except BaseException:
|
||||
if sock_info:
|
||||
if conn:
|
||||
# We checked out a socket but authentication failed.
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
with self.size_cond:
|
||||
self.requests -= 1
|
||||
if incremented:
|
||||
@ -1654,45 +1654,45 @@ class Pool:
|
||||
)
|
||||
raise
|
||||
|
||||
sock_info.active = True
|
||||
return sock_info
|
||||
conn.active = True
|
||||
return conn
|
||||
|
||||
def return_socket(self, sock_info):
|
||||
"""Return the socket to the pool, or if it's closed discard it.
|
||||
def checkin(self, conn):
|
||||
"""Return the connection to the pool, or if it's closed discard it.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info`: The socket to check into the pool.
|
||||
- `conn`: The connection to check into the pool.
|
||||
"""
|
||||
txn = sock_info.pinned_txn
|
||||
cursor = sock_info.pinned_cursor
|
||||
sock_info.active = False
|
||||
sock_info.pinned_txn = False
|
||||
sock_info.pinned_cursor = False
|
||||
self.__pinned_sockets.discard(sock_info)
|
||||
txn = conn.pinned_txn
|
||||
cursor = conn.pinned_cursor
|
||||
conn.active = False
|
||||
conn.pinned_txn = False
|
||||
conn.pinned_cursor = False
|
||||
self.__pinned_sockets.discard(conn)
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_in(self.address, sock_info.id)
|
||||
listeners.publish_connection_checked_in(self.address, conn.id)
|
||||
if self.pid != os.getpid():
|
||||
self.reset_without_pause()
|
||||
else:
|
||||
if self.closed:
|
||||
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
elif sock_info.closed:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
elif conn.closed:
|
||||
# CMAP requires the closed event be emitted after the check in.
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_closed(
|
||||
self.address, sock_info.id, ConnectionClosedReason.ERROR
|
||||
self.address, conn.id, ConnectionClosedReason.ERROR
|
||||
)
|
||||
else:
|
||||
with self.lock:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
if self.stale_generation(sock_info.generation, sock_info.service_id):
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
if self.stale_generation(conn.generation, conn.service_id):
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
else:
|
||||
sock_info.update_last_checkin_time()
|
||||
sock_info.update_is_writable(self.is_writable)
|
||||
self.sockets.appendleft(sock_info)
|
||||
conn.update_last_checkin_time()
|
||||
conn.update_is_writable(self.is_writable)
|
||||
self.conns.appendleft(conn)
|
||||
# Notify any threads waiting to create a connection.
|
||||
self._max_connecting_cond.notify()
|
||||
|
||||
@ -1706,7 +1706,7 @@ class Pool:
|
||||
self.operation_count -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def _perished(self, sock_info):
|
||||
def _perished(self, conn):
|
||||
"""Return True and close the connection if it is "perished".
|
||||
|
||||
This side-effecty function checks if this socket has been idle for
|
||||
@ -1720,24 +1720,24 @@ class Pool:
|
||||
pool, to keep performance reasonable - we can't avoid AutoReconnects
|
||||
completely anyway.
|
||||
"""
|
||||
idle_time_seconds = sock_info.idle_time_seconds()
|
||||
idle_time_seconds = conn.idle_time_seconds()
|
||||
# If socket is idle, open a new one.
|
||||
if (
|
||||
self.opts.max_idle_time_seconds is not None
|
||||
and idle_time_seconds > self.opts.max_idle_time_seconds
|
||||
):
|
||||
sock_info.close_socket(ConnectionClosedReason.IDLE)
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
return True
|
||||
|
||||
if self._check_interval_seconds is not None and (
|
||||
0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds
|
||||
):
|
||||
if sock_info.socket_closed():
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
if conn.conn_closed():
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
return True
|
||||
|
||||
if self.stale_generation(sock_info.generation, sock_info.service_id):
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
if self.stale_generation(conn.generation, conn.service_id):
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -1772,5 +1772,5 @@ class Pool:
|
||||
# Avoid ResourceWarnings in Python 3
|
||||
# Close all sockets without calling reset() or close() because it is
|
||||
# not safe to acquire a lock in __del__.
|
||||
for sock_info in self.sockets:
|
||||
sock_info.close_socket(None)
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
|
||||
@ -81,7 +81,7 @@ def _is_ip_address(address):
|
||||
return False
|
||||
|
||||
|
||||
# According to the docs for Connection.send it can raise
|
||||
# According to the docs for socket.send it can raise
|
||||
# WantX509LookupError and should be retried.
|
||||
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
|
||||
|
||||
@ -347,7 +347,7 @@ class SSLContext:
|
||||
server_hostname=None,
|
||||
session=None,
|
||||
):
|
||||
"""Wrap an existing Python socket sock and return a TLS socket
|
||||
"""Wrap an existing Python socket connection and return a TLS socket
|
||||
object.
|
||||
"""
|
||||
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
|
||||
|
||||
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pool import SocketInfo
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.typings import _Address
|
||||
|
||||
|
||||
@ -85,13 +85,13 @@ class Response:
|
||||
|
||||
|
||||
class PinnedResponse(Response):
|
||||
__slots__ = ("_socket_info", "_more_to_come")
|
||||
__slots__ = ("_conn", "_more_to_come")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Union[_OpMsg, _OpReply],
|
||||
address: _Address,
|
||||
socket_info: SocketInfo,
|
||||
conn: Connection,
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
@ -103,7 +103,7 @@ class PinnedResponse(Response):
|
||||
:Parameters:
|
||||
- `data`: A network response message.
|
||||
- `address`: (host, port) of the source server.
|
||||
- `socket_info`: The SocketInfo used for the initial query.
|
||||
- `conn`: The Connection used for the initial query.
|
||||
- `request_id`: The request id of this operation.
|
||||
- `duration`: The duration of the operation.
|
||||
- `from_command`: If the response is the result of a db command.
|
||||
@ -112,18 +112,18 @@ class PinnedResponse(Response):
|
||||
exhausted.
|
||||
"""
|
||||
super().__init__(data, address, request_id, duration, from_command, docs)
|
||||
self._socket_info = socket_info
|
||||
self._conn = conn
|
||||
self._more_to_come = more_to_come
|
||||
|
||||
@property
|
||||
def socket_info(self) -> SocketInfo:
|
||||
"""The SocketInfo used for the initial query.
|
||||
def conn(self) -> Connection:
|
||||
"""The Connection used for the initial query.
|
||||
|
||||
The server will send batches on this socket, without waiting for
|
||||
getMores from the client, until the result set is exhausted or there
|
||||
is an error.
|
||||
"""
|
||||
return self._socket_info
|
||||
return self._conn
|
||||
|
||||
@property
|
||||
def more_to_come(self) -> bool:
|
||||
|
||||
@ -42,7 +42,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.mongo_client import _MongoClientErrorHandler
|
||||
from pymongo.monitor import Monitor
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import Pool, SocketInfo
|
||||
from pymongo.pool import Connection, Pool
|
||||
from pymongo.server_description import ServerDescription
|
||||
|
||||
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
|
||||
@ -105,7 +105,7 @@ class Server:
|
||||
@_handle_reauth
|
||||
def run_operation(
|
||||
self,
|
||||
sock_info: SocketInfo,
|
||||
conn: Connection,
|
||||
operation: Union[_Query, _GetMore],
|
||||
read_preference: bool,
|
||||
listeners: _EventListeners,
|
||||
@ -118,7 +118,7 @@ class Server:
|
||||
Can raise ConnectionFailure, OperationFailure, etc.
|
||||
|
||||
:Parameters:
|
||||
- `sock_info`: A SocketInfo instance.
|
||||
- `conn`: A Connection instance.
|
||||
- `operation`: A _Query or _GetMore object.
|
||||
- `read_preference`: The read preference to use.
|
||||
- `listeners`: Instance of _EventListeners or None.
|
||||
@ -129,27 +129,27 @@ class Server:
|
||||
if publish:
|
||||
start = datetime.now()
|
||||
|
||||
use_cmd = operation.use_command(sock_info)
|
||||
more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come
|
||||
use_cmd = operation.use_command(conn)
|
||||
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
|
||||
if more_to_come:
|
||||
request_id = 0
|
||||
else:
|
||||
message = operation.get_message(read_preference, sock_info, use_cmd)
|
||||
message = operation.get_message(read_preference, conn, use_cmd)
|
||||
request_id, data, max_doc_size = self._split_message(message)
|
||||
|
||||
if publish:
|
||||
cmd, dbn = operation.as_command(sock_info)
|
||||
cmd, dbn = operation.as_command(conn)
|
||||
listeners.publish_command_start(
|
||||
cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id
|
||||
cmd, dbn, request_id, conn.address, service_id=conn.service_id
|
||||
)
|
||||
start = datetime.now()
|
||||
|
||||
try:
|
||||
if more_to_come:
|
||||
reply = sock_info.receive_message(None)
|
||||
reply = conn.receive_message(None)
|
||||
else:
|
||||
sock_info.send_message(data, max_doc_size)
|
||||
reply = sock_info.receive_message(request_id)
|
||||
conn.send_message(data, max_doc_size)
|
||||
reply = conn.receive_message(request_id)
|
||||
|
||||
# Unpack and check for command errors.
|
||||
if use_cmd:
|
||||
@ -168,7 +168,7 @@ class Server:
|
||||
if use_cmd:
|
||||
first = docs[0]
|
||||
operation.client._process_response(first, operation.session)
|
||||
_check_command_response(first, sock_info.max_wire_version)
|
||||
_check_command_response(first, conn.max_wire_version)
|
||||
except Exception as exc:
|
||||
if publish:
|
||||
duration = datetime.now() - start
|
||||
@ -181,8 +181,8 @@ class Server:
|
||||
failure,
|
||||
operation.name,
|
||||
request_id,
|
||||
sock_info.address,
|
||||
service_id=sock_info.service_id,
|
||||
conn.address,
|
||||
service_id=conn.service_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@ -205,8 +205,8 @@ class Server:
|
||||
res,
|
||||
operation.name,
|
||||
request_id,
|
||||
sock_info.address,
|
||||
service_id=sock_info.service_id,
|
||||
conn.address,
|
||||
service_id=conn.service_id,
|
||||
)
|
||||
|
||||
# Decrypt response.
|
||||
@ -219,7 +219,7 @@ class Server:
|
||||
response: Response
|
||||
|
||||
if client._should_pin_cursor(operation.session) or operation.exhaust:
|
||||
sock_info.pin_cursor()
|
||||
conn.pin_cursor()
|
||||
if isinstance(reply, _OpMsg):
|
||||
# In OP_MSG, the server keeps sending only if the
|
||||
# more_to_come flag is set.
|
||||
@ -227,12 +227,12 @@ class Server:
|
||||
else:
|
||||
# In OP_REPLY, the server keeps sending until cursor_id is 0.
|
||||
more_to_come = bool(operation.exhaust and reply.cursor_id)
|
||||
if operation.sock_mgr:
|
||||
operation.sock_mgr.update_exhaust(more_to_come)
|
||||
if operation.conn_mgr:
|
||||
operation.conn_mgr.update_exhaust(more_to_come)
|
||||
response = PinnedResponse(
|
||||
data=reply,
|
||||
address=self._description.address,
|
||||
socket_info=sock_info,
|
||||
conn=conn,
|
||||
duration=duration,
|
||||
request_id=request_id,
|
||||
from_command=use_cmd,
|
||||
@ -251,10 +251,10 @@ class Server:
|
||||
|
||||
return response
|
||||
|
||||
def get_socket(
|
||||
def checkout(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
) -> ContextManager[SocketInfo]:
|
||||
return self.pool.get_socket(handler)
|
||||
) -> ContextManager[Connection]:
|
||||
return self.pool.checkout(handler)
|
||||
|
||||
@property
|
||||
def description(self) -> ServerDescription:
|
||||
|
||||
@ -106,7 +106,7 @@ class TestHandshake(unittest.TestCase):
|
||||
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# New monitoring sockets send data during handshake.
|
||||
# New monitoring connections send data during handshake.
|
||||
heartbeat = primary.receives("ismaster")
|
||||
_check_handshake_data(heartbeat)
|
||||
heartbeat.ok(primary_response)
|
||||
@ -169,7 +169,7 @@ class TestHandshake(unittest.TestCase):
|
||||
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# New monitoring sockets send data during handshake.
|
||||
# New monitoring connections send data during handshake.
|
||||
heartbeat = server.receives("ismaster")
|
||||
heartbeat.ok(primary_response)
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class MockPool(Pool):
|
||||
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, handler=None):
|
||||
def checkout(self, handler=None):
|
||||
client = self.client
|
||||
host_and_port = f"{self.mock_host}:{self.mock_port}"
|
||||
if host_and_port in client.mock_down_hosts:
|
||||
@ -48,10 +48,10 @@ class MockPool(Pool):
|
||||
client.mock_standalones + client.mock_members + client.mock_mongoses
|
||||
), ("bad host: %s" % host_and_port)
|
||||
|
||||
with Pool.get_socket(self, handler) as sock_info:
|
||||
sock_info.mock_host = self.mock_host
|
||||
sock_info.mock_port = self.mock_port
|
||||
yield sock_info
|
||||
with Pool.checkout(self, handler) as conn:
|
||||
conn.mock_host = self.mock_host
|
||||
conn.mock_port = self.mock_port
|
||||
yield conn
|
||||
|
||||
|
||||
class DummyMonitor:
|
||||
|
||||
@ -60,7 +60,7 @@ class AutoAuthenticateThread(threading.Thread):
|
||||
"""Used in testing threaded authentication.
|
||||
|
||||
This does collection.find_one() with a 1-second delay to ensure it must
|
||||
check out and authenticate multiple sockets from the pool concurrently.
|
||||
check out and authenticate multiple connections from the pool concurrently.
|
||||
|
||||
:Parameters:
|
||||
`collection`: An auth-protected collection containing one document.
|
||||
@ -217,7 +217,7 @@ class TestGSSAPI(unittest.TestCase):
|
||||
|
||||
# Need one document in the collection. AutoAuthenticateThread does
|
||||
# collection.find_one with a 1-second delay, forcing it to check out
|
||||
# multiple sockets from the pool concurrently, proving that
|
||||
# multiple connections from the pool concurrently, proving that
|
||||
# auto-authentication works with GSSAPI.
|
||||
collection = db.test
|
||||
if not collection.count_documents({}):
|
||||
|
||||
@ -96,7 +96,7 @@ from pymongo.errors import (
|
||||
)
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
|
||||
from pymongo.pool import _METADATA, PoolOptions, SocketInfo
|
||||
from pymongo.pool import _METADATA, Connection, PoolOptions
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import readable_server_selector, writable_server_selector
|
||||
@ -538,13 +538,13 @@ class TestClient(IntegrationTest):
|
||||
|
||||
def test_max_idle_time_reaper_default(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
# Assert reaper doesn't remove sockets when maxIdleTimeMS not set
|
||||
# Assert reaper doesn't remove connections when maxIdleTimeMS not set
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertTrue(sock_info in server._pool.sockets)
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
self.assertTrue(conn in server._pool.conns)
|
||||
client.close()
|
||||
|
||||
def test_max_idle_time_reaper_removes_stale_minPoolSize(self):
|
||||
@ -552,27 +552,27 @@ class TestClient(IntegrationTest):
|
||||
# Assert reaper removes idle socket and replaces it with a new one
|
||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket, two
|
||||
# sockets could be created and checked into the pool.
|
||||
self.assertGreaterEqual(len(server._pool.sockets), 1)
|
||||
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket")
|
||||
wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket")
|
||||
# connections could be created and checked into the pool.
|
||||
self.assertGreaterEqual(len(server._pool.conns), 1)
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: 1 <= len(server._pool.conns), "replace stale socket")
|
||||
client.close()
|
||||
|
||||
def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
# Assert reaper respects maxPoolSize when adding new sockets.
|
||||
# Assert reaper respects maxPoolSize when adding new connections.
|
||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket,
|
||||
# maxPoolSize=1 should prevent two sockets from being created.
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket")
|
||||
wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket")
|
||||
# maxPoolSize=1 should prevent two connections from being created.
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: 1 == len(server._pool.conns), "replace stale socket")
|
||||
client.close()
|
||||
|
||||
def test_max_idle_time_reaper_removes_stale(self):
|
||||
@ -580,15 +580,15 @@ class TestClient(IntegrationTest):
|
||||
# Assert reaper has removed idle socket and NOT replaced it
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info_one:
|
||||
with server._pool.checkout() as conn_one:
|
||||
pass
|
||||
# Assert that the pool does not close sockets prematurely.
|
||||
# Assert that the pool does not close connections prematurely.
|
||||
time.sleep(0.300)
|
||||
with server._pool.get_socket() as sock_info_two:
|
||||
with server._pool.checkout() as conn_two:
|
||||
pass
|
||||
self.assertIs(sock_info_one, sock_info_two)
|
||||
self.assertIs(conn_one, conn_two)
|
||||
wait_until(
|
||||
lambda: 0 == len(server._pool.sockets),
|
||||
lambda: 0 == len(server._pool.conns),
|
||||
"stale socket reaped and new one NOT added to the pool",
|
||||
)
|
||||
client.close()
|
||||
@ -597,48 +597,50 @@ class TestClient(IntegrationTest):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
self.assertEqual(0, len(server._pool.sockets))
|
||||
self.assertEqual(0, len(server._pool.conns))
|
||||
|
||||
# Assert that pool started up at minPoolSize
|
||||
client = rs_or_single_client(minPoolSize=10)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets")
|
||||
wait_until(
|
||||
lambda: 10 == len(server._pool.conns), "pool initialized with 10 connections"
|
||||
)
|
||||
|
||||
# Assert that if a socket is closed, a new one takes its place
|
||||
with server._pool.get_socket() as sock_info:
|
||||
sock_info.close_socket(None)
|
||||
with server._pool.checkout() as conn:
|
||||
conn.close_conn(None)
|
||||
wait_until(
|
||||
lambda: 10 == len(server._pool.sockets),
|
||||
lambda: 10 == len(server._pool.conns),
|
||||
"a closed socket gets replaced from the pool",
|
||||
)
|
||||
self.assertFalse(sock_info in server._pool.sockets)
|
||||
self.assertFalse(conn in server._pool.conns)
|
||||
|
||||
def test_max_idle_time_checkout(self):
|
||||
# Use high frequency to test _get_socket_no_auth.
|
||||
with client_knobs(kill_cursor_frequency=99999999):
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
time.sleep(1) # Sleep so that the socket becomes stale.
|
||||
|
||||
with server._pool.get_socket() as new_sock_info:
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertFalse(sock_info in server._pool.sockets)
|
||||
self.assertTrue(new_sock_info in server._pool.sockets)
|
||||
with server._pool.checkout() as new_con:
|
||||
self.assertNotEqual(conn, new_con)
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
self.assertFalse(conn in server._pool.conns)
|
||||
self.assertTrue(new_con in server._pool.conns)
|
||||
|
||||
# Test that sockets are reused if maxIdleTimeMS is not set.
|
||||
# Test that connections are reused if maxIdleTimeMS is not set.
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector)
|
||||
with server._pool.get_socket() as sock_info:
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
time.sleep(1)
|
||||
with server._pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
self.assertEqual(1, len(server._pool.sockets))
|
||||
with server._pool.checkout() as new_con:
|
||||
self.assertEqual(conn, new_con)
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
|
||||
def test_constants(self):
|
||||
"""This test uses MongoClient explicitly to make sure that host and
|
||||
@ -933,11 +935,11 @@ class TestClient(IntegrationTest):
|
||||
topology = client._topology
|
||||
client.close()
|
||||
for server in topology._servers.values():
|
||||
self.assertFalse(server._pool.sockets)
|
||||
self.assertFalse(server._pool.conns)
|
||||
self.assertTrue(server._monitor._executor._stopped)
|
||||
self.assertTrue(server._monitor._rtt_monitor._executor._stopped)
|
||||
self.assertFalse(server._monitor._pool.sockets)
|
||||
self.assertFalse(server._monitor._rtt_monitor._pool.sockets)
|
||||
self.assertFalse(server._monitor._pool.conns)
|
||||
self.assertFalse(server._monitor._rtt_monitor._pool.conns)
|
||||
|
||||
def test_bad_uri(self):
|
||||
with self.assertRaises(InvalidURI):
|
||||
@ -1130,8 +1132,8 @@ class TestClient(IntegrationTest):
|
||||
|
||||
def test_socketKeepAlive(self):
|
||||
pool = get_pool(self.client)
|
||||
with pool.get_socket() as sock_info:
|
||||
keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
with pool.checkout() as conn:
|
||||
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
self.assertTrue(keepalive)
|
||||
|
||||
@no_type_check
|
||||
@ -1184,7 +1186,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
# The socket used for the previous commands has been returned to the
|
||||
# pool
|
||||
self.assertEqual(1, len(get_pool(client).sockets))
|
||||
self.assertEqual(1, len(get_pool(client).conns))
|
||||
|
||||
with contextlib.closing(client):
|
||||
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
|
||||
@ -1223,7 +1225,7 @@ class TestClient(IntegrationTest):
|
||||
# main thread while find() is in-progress: On Windows, SIGALRM is
|
||||
# unavailable so we use a second thread. In our Evergreen setup on
|
||||
# Linux, the thread technique causes an error in the test at
|
||||
# sock.recv(): TypeError: 'int' object is not callable
|
||||
# conn.recv(): TypeError: 'int' object is not callable
|
||||
# We don't know what causes this, so we hack around it.
|
||||
|
||||
if sys.platform == "win32":
|
||||
@ -1271,16 +1273,16 @@ class TestClient(IntegrationTest):
|
||||
self.addCleanup(client.close)
|
||||
client.pymongo_test.test.find_one()
|
||||
pool = get_pool(client)
|
||||
socket_count = len(pool.sockets)
|
||||
socket_count = len(pool.conns)
|
||||
self.assertGreaterEqual(socket_count, 1)
|
||||
old_sock_info = next(iter(pool.sockets))
|
||||
old_conn = next(iter(pool.conns))
|
||||
client.pymongo_test.test.drop()
|
||||
client.pymongo_test.test.insert_one({"_id": "foo"})
|
||||
self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"})
|
||||
|
||||
self.assertEqual(socket_count, len(pool.sockets))
|
||||
new_sock_info = next(iter(pool.sockets))
|
||||
self.assertEqual(old_sock_info, new_sock_info)
|
||||
self.assertEqual(socket_count, len(pool.conns))
|
||||
new_con = next(iter(pool.conns))
|
||||
self.assertEqual(old_conn, new_con)
|
||||
|
||||
def test_lazy_connect_w0(self):
|
||||
# Ensure that connect-on-demand works when the first operation is
|
||||
@ -1326,13 +1328,13 @@ class TestClient(IntegrationTest):
|
||||
connected(client)
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = one(pool.sockets)
|
||||
sock_info.sock.close()
|
||||
conn = one(pool.conns)
|
||||
conn.conn.close()
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
next(cursor)
|
||||
|
||||
self.assertTrue(sock_info.closed)
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
# The semaphore was decremented despite the error.
|
||||
self.assertEqual(0, pool.requests)
|
||||
@ -1347,10 +1349,10 @@ class TestClient(IntegrationTest):
|
||||
|
||||
# Cause a network error on the actual socket.
|
||||
pool = get_pool(c)
|
||||
socket_info = one(pool.sockets)
|
||||
socket_info.sock.close()
|
||||
socket_info = one(pool.conns)
|
||||
socket_info.conn.close()
|
||||
|
||||
# SocketInfo.authenticate logs, but gets a socket.error. Should be
|
||||
# Connection.authenticate logs, but gets a socket.error. Should be
|
||||
# reraised as AutoReconnect.
|
||||
self.assertRaises(AutoReconnect, c.test.collection.find_one)
|
||||
|
||||
@ -1586,7 +1588,7 @@ class TestClient(IntegrationTest):
|
||||
self.addCleanup(delattr, pool, "connect")
|
||||
|
||||
# Wait for the background thread to start creating connections
|
||||
wait_until(lambda: len(pool.sockets) > 1, "start creating connections")
|
||||
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
||||
|
||||
# Assert that application operations do not block.
|
||||
for _ in range(10):
|
||||
@ -1847,7 +1849,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
|
||||
collection = client.pymongo_test.test
|
||||
pool = get_pool(client)
|
||||
sock_info = one(pool.sockets)
|
||||
conn = one(pool.conns)
|
||||
|
||||
# This will cause OperationFailure in all mongo versions since
|
||||
# the value for $orderby must be a document.
|
||||
@ -1856,10 +1858,10 @@ class TestExhaustCursor(IntegrationTest):
|
||||
)
|
||||
|
||||
self.assertRaises(OperationFailure, cursor.next)
|
||||
self.assertFalse(sock_info.closed)
|
||||
self.assertFalse(conn.closed)
|
||||
|
||||
# The socket was checked in and the semaphore was decremented.
|
||||
self.assertIn(sock_info, pool.sockets)
|
||||
self.assertIn(conn, pool.conns)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
def test_exhaust_getmore_server_error(self):
|
||||
@ -1874,7 +1876,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
|
||||
pool = get_pool(client)
|
||||
pool._check_interval_seconds = None # Never check.
|
||||
sock_info = one(pool.sockets)
|
||||
conn = one(pool.conns)
|
||||
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
|
||||
@ -1884,21 +1886,21 @@ class TestExhaustCursor(IntegrationTest):
|
||||
# Cause a server error on getmore.
|
||||
def receive_message(request_id):
|
||||
# Discard the actual server response.
|
||||
SocketInfo.receive_message(sock_info, request_id)
|
||||
Connection.receive_message(conn, request_id)
|
||||
|
||||
# responseFlags bit 1 is QueryFailure.
|
||||
msg = struct.pack("<iiiii", 1 << 1, 0, 0, 0, 0)
|
||||
msg += encode({"$err": "mock err", "code": 0})
|
||||
return message._OpReply.unpack(msg)
|
||||
|
||||
sock_info.receive_message = receive_message
|
||||
conn.receive_message = receive_message
|
||||
self.assertRaises(OperationFailure, list, cursor)
|
||||
# Unpatch the instance.
|
||||
del sock_info.receive_message
|
||||
del conn.receive_message
|
||||
|
||||
# The socket is returned to the pool and it still works.
|
||||
self.assertEqual(200, collection.count_documents({}))
|
||||
self.assertIn(sock_info, pool.sockets)
|
||||
self.assertIn(conn, pool.conns)
|
||||
|
||||
def test_exhaust_query_network_error(self):
|
||||
# When doing an exhaust query, the socket stays checked out on success
|
||||
@ -1909,15 +1911,15 @@ class TestExhaustCursor(IntegrationTest):
|
||||
pool._check_interval_seconds = None # Never check.
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = one(pool.sockets)
|
||||
sock_info.sock.close()
|
||||
conn = one(pool.conns)
|
||||
conn.conn.close()
|
||||
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
self.assertRaises(ConnectionFailure, cursor.next)
|
||||
self.assertTrue(sock_info.closed)
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
# The socket was closed and the semaphore was decremented.
|
||||
self.assertNotIn(sock_info, pool.sockets)
|
||||
self.assertNotIn(conn, pool.conns)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
def test_exhaust_getmore_network_error(self):
|
||||
@ -1936,15 +1938,15 @@ class TestExhaustCursor(IntegrationTest):
|
||||
cursor.next()
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = cursor._Cursor__sock_mgr.sock
|
||||
sock_info.sock.close()
|
||||
conn = cursor._Cursor__sock_mgr.conn
|
||||
conn.conn.close()
|
||||
|
||||
# A getmore fails.
|
||||
self.assertRaises(ConnectionFailure, list, cursor)
|
||||
self.assertTrue(sock_info.closed)
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
# The socket was closed and the semaphore was decremented.
|
||||
self.assertNotIn(sock_info, pool.sockets)
|
||||
self.assertNotIn(conn, pool.conns)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
def test_gevent_task(self):
|
||||
@ -2243,10 +2245,10 @@ class TestClientPool(MockClientTest):
|
||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
|
||||
# Assert that we do not create connections to arbiters.
|
||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||
self.assertFalse(arbiter.pool.sockets)
|
||||
self.assertFalse(arbiter.pool.conns)
|
||||
# Assert that we do not create connections to unknown servers.
|
||||
arbiter = c._topology.get_server_by_address(("d", 4))
|
||||
self.assertFalse(arbiter.pool.sockets)
|
||||
self.assertFalse(arbiter.pool.conns)
|
||||
# Arbiter pool is not marked ready.
|
||||
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 2)
|
||||
|
||||
@ -2271,7 +2273,7 @@ class TestClientPool(MockClientTest):
|
||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
|
||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||
self.assertEqual(len(arbiter.pool.sockets), 1)
|
||||
self.assertEqual(len(arbiter.pool.conns), 1)
|
||||
# Arbiter pool is marked ready.
|
||||
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)
|
||||
|
||||
|
||||
@ -123,19 +123,19 @@ class TestCMAP(IntegrationTest):
|
||||
def check_out(self, op):
|
||||
"""Run the 'checkOut' operation."""
|
||||
label = op["label"]
|
||||
with self.pool.get_socket() as sock_info:
|
||||
with self.pool.checkout() as conn:
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
sock_info.pin_cursor()
|
||||
conn.pin_cursor()
|
||||
if label:
|
||||
self.labels[label] = sock_info
|
||||
self.labels[label] = conn
|
||||
else:
|
||||
self.addCleanup(sock_info.close_socket, None)
|
||||
self.addCleanup(conn.close_conn, None)
|
||||
|
||||
def check_in(self, op):
|
||||
"""Run the 'checkIn' operation."""
|
||||
label = op["connection"]
|
||||
sock_info = self.labels[label]
|
||||
self.pool.return_socket(sock_info)
|
||||
conn = self.labels[label]
|
||||
self.pool.checkin(conn)
|
||||
|
||||
def ready(self, op):
|
||||
"""Run the 'ready' operation."""
|
||||
@ -270,7 +270,7 @@ class TestCMAP(IntegrationTest):
|
||||
for t in self.targets.values():
|
||||
t.join(5)
|
||||
for conn in self.labels.values():
|
||||
conn.close_socket(None)
|
||||
conn.close_conn(None)
|
||||
|
||||
self.addCleanup(cleanup)
|
||||
|
||||
@ -444,7 +444,7 @@ class TestCMAP(IntegrationTest):
|
||||
self.assertEqual(1, listener.event_count(PoolClearedEvent))
|
||||
self.assertEqual(PoolState.READY, pool.state)
|
||||
# Checking out a connection should succeed
|
||||
with pool.get_socket():
|
||||
with pool.checkout():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -1721,15 +1721,15 @@ class TestCollection(IntegrationTest):
|
||||
# Make sure the socket is returned after exhaustion.
|
||||
cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST)
|
||||
next(cur)
|
||||
self.assertEqual(0, len(pool.sockets))
|
||||
self.assertEqual(0, len(pool.conns))
|
||||
for _ in cur:
|
||||
pass
|
||||
self.assertEqual(1, len(pool.sockets))
|
||||
self.assertEqual(1, len(pool.conns))
|
||||
|
||||
# Same as previous but don't call next()
|
||||
for _ in client[self.db.name].test.find(cursor_type=CursorType.EXHAUST):
|
||||
pass
|
||||
self.assertEqual(1, len(pool.sockets))
|
||||
self.assertEqual(1, len(pool.conns))
|
||||
|
||||
# If the Cursor instance is discarded before being completely iterated
|
||||
# and the socket has pending data (more_to_come=True) we have to close
|
||||
@ -1742,7 +1742,7 @@ class TestCollection(IntegrationTest):
|
||||
next(cur)
|
||||
else:
|
||||
next(cur)
|
||||
self.assertEqual(0, len(pool.sockets))
|
||||
self.assertEqual(0, len(pool.conns))
|
||||
if sys.platform.startswith("java") or "PyPy" in sys.version:
|
||||
# Don't wait for GC or use gc.collect(), it's unreliable.
|
||||
cur.close()
|
||||
@ -1750,7 +1750,7 @@ class TestCollection(IntegrationTest):
|
||||
# Wait until the background thread returns the socket.
|
||||
wait_until(lambda: pool.active_sockets == 0, "return socket")
|
||||
# The socket should be discarded.
|
||||
self.assertEqual(0, len(pool.sockets))
|
||||
self.assertEqual(0, len(pool.conns))
|
||||
|
||||
def test_distinct(self):
|
||||
self.db.drop_collection("test")
|
||||
|
||||
@ -47,7 +47,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
)
|
||||
|
||||
# Ensure connections to all servers in replica set. This is to test
|
||||
# that the is_writable flag is properly updated for sockets that
|
||||
# that the is_writable flag is properly updated for connections that
|
||||
# survive a replica set election.
|
||||
ensure_all_connected(cls.client)
|
||||
cls.listener.reset()
|
||||
|
||||
@ -94,7 +94,7 @@ def got_app_error(topology, app_error):
|
||||
when = app_error["when"]
|
||||
max_wire_version = app_error["maxWireVersion"]
|
||||
# XXX: We could get better test coverage by mocking the errors on the
|
||||
# Pool/SocketInfo.
|
||||
# Pool/Connection.
|
||||
try:
|
||||
if error_type == "command":
|
||||
_check_command_response(app_error["response"], max_wire_version)
|
||||
@ -274,14 +274,14 @@ class TestIgnoreStaleErrors(IntegrationTest):
|
||||
client.admin.command("ping")
|
||||
pool = get_pool(client)
|
||||
starting_generation = pool.gen.get_overall()
|
||||
wait_until(lambda: len(pool.sockets) == N_THREADS, "created sockets")
|
||||
wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
|
||||
|
||||
def mock_command(*args, **kwargs):
|
||||
# Synchronize all threads to ensure they use the same generation.
|
||||
barrier.wait()
|
||||
raise AutoReconnect("mock SocketInfo.command error")
|
||||
raise AutoReconnect("mock Connection.command error")
|
||||
|
||||
for sock in pool.sockets:
|
||||
for sock in pool.conns:
|
||||
sock.command = mock_command
|
||||
|
||||
def insert_command(i):
|
||||
|
||||
@ -41,11 +41,11 @@ class TestLB(IntegrationTest):
|
||||
# Tracked in PYTHON-3011
|
||||
self.skipTest("Test is flaky on PyPy")
|
||||
pool = get_pool(self.client)
|
||||
nconns = len(pool.sockets)
|
||||
n_conns = len(pool.conns)
|
||||
self.db.test.find_one({})
|
||||
self.assertEqual(len(pool.sockets), nconns)
|
||||
self.assertEqual(len(pool.conns), n_conns)
|
||||
list(self.db.test.aggregate([{"$limit": 1}]))
|
||||
self.assertEqual(len(pool.sockets), nconns)
|
||||
self.assertEqual(len(pool.conns), n_conns)
|
||||
|
||||
@client_context.require_load_balancer
|
||||
def test_unpin_committed_transaction(self):
|
||||
|
||||
@ -116,15 +116,15 @@ class SocketGetter(MongoThread):
|
||||
self.state = "get_socket"
|
||||
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
with self.pool.get_socket() as sock:
|
||||
with self.pool.checkout() as sock:
|
||||
sock.pin_cursor()
|
||||
self.sock = sock
|
||||
|
||||
self.state = "sock"
|
||||
self.state = "connection"
|
||||
|
||||
def __del__(self):
|
||||
if self.sock:
|
||||
self.sock.close_socket(None)
|
||||
self.sock.close_conn(None)
|
||||
|
||||
|
||||
def run_cases(client, cases):
|
||||
@ -188,36 +188,36 @@ class TestPooling(_TestPoolingBase):
|
||||
# Test Pool's _check_closed() method doesn't close a healthy socket.
|
||||
cx_pool = self.create_pool(max_pool_size=10)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
with cx_pool.checkout() as conn:
|
||||
pass
|
||||
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
with cx_pool.checkout() as new_connection:
|
||||
self.assertEqual(conn, new_connection)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
self.assertEqual(1, len(cx_pool.conns))
|
||||
|
||||
def test_get_socket_and_exception(self):
|
||||
# get_socket() returns socket after a non-network error.
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
with cx_pool.checkout() as conn:
|
||||
1 / 0
|
||||
|
||||
# Socket was returned, not closed.
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(sock_info, new_sock_info)
|
||||
with cx_pool.checkout() as new_connection:
|
||||
self.assertEqual(conn, new_connection)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
self.assertEqual(1, len(cx_pool.conns))
|
||||
|
||||
def test_pool_removes_closed_socket(self):
|
||||
# Test that Pool removes explicitly closed socket.
|
||||
cx_pool = self.create_pool()
|
||||
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Use SocketInfo's API to close the socket.
|
||||
sock_info.close_socket(None)
|
||||
with cx_pool.checkout() as conn:
|
||||
# Use Connection's API to close the socket.
|
||||
conn.close_conn(None)
|
||||
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertEqual(0, len(cx_pool.conns))
|
||||
|
||||
def test_pool_removes_dead_socket(self):
|
||||
# Test that Pool removes dead socket and the socket doesn't return
|
||||
@ -225,20 +225,20 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
with cx_pool.checkout() as conn:
|
||||
# Simulate a closed socket without telling the Connection it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
self.assertTrue(sock_info.socket_closed())
|
||||
conn.conn.close()
|
||||
self.assertTrue(conn.conn_closed())
|
||||
|
||||
with cx_pool.get_socket() as new_sock_info:
|
||||
self.assertEqual(0, len(cx_pool.sockets))
|
||||
self.assertNotEqual(sock_info, new_sock_info)
|
||||
with cx_pool.checkout() as new_connection:
|
||||
self.assertEqual(0, len(cx_pool.conns))
|
||||
self.assertNotEqual(conn, new_connection)
|
||||
|
||||
self.assertEqual(1, len(cx_pool.sockets))
|
||||
self.assertEqual(1, len(cx_pool.conns))
|
||||
|
||||
# Semaphore was released.
|
||||
with cx_pool.get_socket():
|
||||
with cx_pool.checkout():
|
||||
pass
|
||||
|
||||
def test_socket_closed(self):
|
||||
@ -282,13 +282,13 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
def test_return_socket_after_reset(self):
|
||||
pool = self.create_pool()
|
||||
with pool.get_socket() as sock:
|
||||
with pool.checkout() as sock:
|
||||
self.assertEqual(pool.active_sockets, 1)
|
||||
self.assertEqual(pool.operation_count, 1)
|
||||
pool.reset()
|
||||
|
||||
self.assertTrue(sock.closed)
|
||||
self.assertEqual(0, len(pool.sockets))
|
||||
self.assertEqual(0, len(pool.conns))
|
||||
self.assertEqual(pool.active_sockets, 0)
|
||||
self.assertEqual(pool.operation_count, 0)
|
||||
|
||||
@ -299,20 +299,20 @@ class TestPooling(_TestPoolingBase):
|
||||
cx_pool._check_interval_seconds = 0 # Always check.
|
||||
self.addCleanup(cx_pool.close)
|
||||
|
||||
with cx_pool.get_socket() as sock_info:
|
||||
# Simulate a closed socket without telling the SocketInfo it's
|
||||
with cx_pool.checkout() as conn:
|
||||
# Simulate a closed socket without telling the Connection it's
|
||||
# closed.
|
||||
sock_info.sock.close()
|
||||
conn.conn.close()
|
||||
|
||||
# Swap pool's address with a bad one.
|
||||
address, cx_pool.address = cx_pool.address, ("foo.com", 1234)
|
||||
with self.assertRaises(AutoReconnect):
|
||||
with cx_pool.get_socket():
|
||||
with cx_pool.checkout():
|
||||
pass
|
||||
|
||||
# Back to normal, semaphore was correctly released.
|
||||
cx_pool.address = address
|
||||
with cx_pool.get_socket():
|
||||
with cx_pool.checkout():
|
||||
pass
|
||||
|
||||
def test_wait_queue_timeout(self):
|
||||
@ -320,10 +320,10 @@ class TestPooling(_TestPoolingBase):
|
||||
pool = self.create_pool(max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
|
||||
self.addCleanup(pool.close)
|
||||
|
||||
with pool.get_socket():
|
||||
with pool.checkout():
|
||||
start = time.time()
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
with pool.get_socket():
|
||||
with pool.checkout():
|
||||
pass
|
||||
|
||||
duration = time.time() - start
|
||||
@ -338,7 +338,7 @@ class TestPooling(_TestPoolingBase):
|
||||
self.addCleanup(pool.close)
|
||||
|
||||
# Reach max_size.
|
||||
with pool.get_socket() as s1:
|
||||
with pool.checkout() as s1:
|
||||
t = SocketGetter(self.c, pool)
|
||||
t.start()
|
||||
while t.state != "get_socket":
|
||||
@ -347,10 +347,10 @@ class TestPooling(_TestPoolingBase):
|
||||
time.sleep(1)
|
||||
self.assertEqual(t.state, "get_socket")
|
||||
|
||||
while t.state != "sock":
|
||||
while t.state != "connection":
|
||||
time.sleep(0.1)
|
||||
|
||||
self.assertEqual(t.state, "sock")
|
||||
self.assertEqual(t.state, "connection")
|
||||
self.assertEqual(t.sock, s1)
|
||||
|
||||
def test_checkout_more_than_max_pool_size(self):
|
||||
@ -359,7 +359,7 @@ class TestPooling(_TestPoolingBase):
|
||||
socks = []
|
||||
for _ in range(2):
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
with pool.get_socket() as sock:
|
||||
with pool.checkout() as sock:
|
||||
sock.pin_cursor()
|
||||
socks.append(sock)
|
||||
|
||||
@ -373,7 +373,7 @@ class TestPooling(_TestPoolingBase):
|
||||
self.assertEqual(t.state, "get_socket")
|
||||
|
||||
for socket_info in socks:
|
||||
socket_info.close_socket(None)
|
||||
socket_info.close_conn(None)
|
||||
|
||||
def test_maxConnecting(self):
|
||||
client = rs_or_single_client()
|
||||
@ -394,14 +394,14 @@ class TestPooling(_TestPoolingBase):
|
||||
thread.join(10)
|
||||
|
||||
self.assertEqual(len(docs), 50)
|
||||
self.assertLessEqual(len(pool.sockets), 50)
|
||||
self.assertLessEqual(len(pool.conns), 50)
|
||||
# TLS and auth make connection establishment more expensive than
|
||||
# the query which leads to more threads hitting maxConnecting.
|
||||
# The end result is fewer total connections and better latency.
|
||||
if client_context.tls and client_context.auth_enabled:
|
||||
self.assertLessEqual(len(pool.sockets), 30)
|
||||
self.assertLessEqual(len(pool.conns), 30)
|
||||
else:
|
||||
self.assertLessEqual(len(pool.sockets), 50)
|
||||
self.assertLessEqual(len(pool.conns), 50)
|
||||
# MongoDB 4.4.1 with auth + ssl:
|
||||
# maxConnecting = 2: 6 connections in ~0.231+ seconds
|
||||
# maxConnecting = unbounded: 50 connections in ~0.642+ seconds
|
||||
@ -409,7 +409,7 @@ class TestPooling(_TestPoolingBase):
|
||||
# MongoDB 4.4.1 with no-auth no-ssl Python 3.8:
|
||||
# maxConnecting = 2: 15-22 connections in ~0.108+ seconds
|
||||
# maxConnecting = unbounded: 30+ connections in ~0.140+ seconds
|
||||
print(len(pool.sockets))
|
||||
print(len(pool.conns))
|
||||
|
||||
|
||||
class TestPoolMaxSize(_TestPoolingBase):
|
||||
@ -424,7 +424,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
collection.insert_one({})
|
||||
|
||||
# nthreads had better be much larger than max_pool_size to ensure that
|
||||
# max_pool_size sockets are actually required at some point in this
|
||||
# max_pool_size connections are actually required at some point in this
|
||||
# test's execution.
|
||||
cx_pool = get_pool(c)
|
||||
nthreads = 10
|
||||
@ -435,7 +435,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
def f():
|
||||
for _ in range(5):
|
||||
collection.find_one({"$where": delay(0.1)})
|
||||
assert len(cx_pool.sockets) <= max_pool_size
|
||||
assert len(cx_pool.conns) <= max_pool_size
|
||||
|
||||
with lock:
|
||||
self.n_passed += 1
|
||||
@ -447,7 +447,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
|
||||
joinall(threads)
|
||||
self.assertEqual(nthreads, self.n_passed)
|
||||
self.assertTrue(len(cx_pool.sockets) > 1)
|
||||
self.assertTrue(len(cx_pool.conns) > 1)
|
||||
self.assertEqual(0, cx_pool.requests)
|
||||
|
||||
def test_max_pool_size_none(self):
|
||||
@ -479,7 +479,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
|
||||
joinall(threads)
|
||||
self.assertEqual(nthreads, self.n_passed)
|
||||
self.assertTrue(len(cx_pool.sockets) > 1)
|
||||
self.assertTrue(len(cx_pool.conns) > 1)
|
||||
self.assertEqual(cx_pool.max_pool_size, float("inf"))
|
||||
|
||||
def test_max_pool_size_zero(self):
|
||||
@ -502,7 +502,7 @@ class TestPoolMaxSize(_TestPoolingBase):
|
||||
# socket from pool" instead of AutoReconnect.
|
||||
for _i in range(2):
|
||||
with self.assertRaises(AutoReconnect) as context:
|
||||
with test_pool.get_socket():
|
||||
with test_pool.checkout():
|
||||
pass
|
||||
|
||||
# Testing for AutoReconnect instead of ConnectionFailure, above,
|
||||
|
||||
@ -284,18 +284,18 @@ class ReadPrefTester(MongoClient):
|
||||
super().__init__(*args, **client_options)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _socket_for_reads(self, read_preference, session):
|
||||
context = super()._socket_for_reads(read_preference, session)
|
||||
with context as (sock_info, read_preference):
|
||||
self.record_a_read(sock_info.address)
|
||||
yield sock_info, read_preference
|
||||
def _conn_for_reads(self, read_preference, session):
|
||||
context = super()._conn_for_reads(read_preference, session)
|
||||
with context as (conn, read_preference):
|
||||
self.record_a_read(conn.address)
|
||||
yield conn, read_preference
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _socket_from_server(self, read_preference, server, session):
|
||||
context = super()._socket_from_server(read_preference, server, session)
|
||||
with context as (sock_info, read_preference):
|
||||
self.record_a_read(sock_info.address)
|
||||
yield sock_info, read_preference
|
||||
def _conn_from_server(self, read_preference, server, session):
|
||||
context = super()._conn_from_server(read_preference, server, session)
|
||||
with context as (conn, read_preference):
|
||||
self.record_a_read(conn.address)
|
||||
yield conn, read_preference
|
||||
|
||||
def record_a_read(self, address):
|
||||
server = self._get_topology().select_server_by_address(address, 0)
|
||||
|
||||
@ -141,7 +141,7 @@ class TestProse(IntegrationTest):
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
|
||||
wait_until(lambda: len(get_pool(client).sockets) >= 10, "create 10 connections")
|
||||
wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections")
|
||||
# Delay find commands on
|
||||
delay_finds = {
|
||||
"configureFailPoint": "failCommand",
|
||||
|
||||
@ -279,12 +279,12 @@ class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener):
|
||||
self.add_event(event)
|
||||
|
||||
|
||||
class MockSocketInfo:
|
||||
class MockConnection:
|
||||
def __init__(self):
|
||||
self.cancel_context = _CancellationContext()
|
||||
self.more_to_come = False
|
||||
|
||||
def close_socket(self, reason):
|
||||
def close_conn(self, reason):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
@ -304,10 +304,10 @@ class MockPool:
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def get_socket(self, handler=None):
|
||||
return MockSocketInfo()
|
||||
def checkout(self, handler=None):
|
||||
return MockConnection()
|
||||
|
||||
def return_socket(self, *args, **kwargs):
|
||||
def checkin(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _reset(self, service_id=None):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user