Pin transactions to a single server address

This commit is contained in:
Shane Harvey 2018-03-23 22:53:21 -04:00 committed by A. Jesse Jiryu Davis
parent 116d2c278f
commit 656aa1e703
6 changed files with 72 additions and 42 deletions

View File

@ -427,7 +427,7 @@ class _Bulk(object):
client = self.collection.database.client
if not write_concern.acknowledged:
with client._socket_for_writes() as sock_info:
with client._socket_for_writes(session) as sock_info:
self.execute_no_results(sock_info, generator)
else:
return self.execute_command(generator, write_concern, session)

View File

@ -118,6 +118,7 @@ class ClientSession(object):
self._cluster_time = None
self._operation_time = None
self._current_txn_read_pref = None
self._current_txn_address = None
if self.options.auto_start_transaction:
# TODO: Get transaction options from self.options.
self._current_transaction_opts = TransactionOptions()
@ -240,6 +241,8 @@ class ClientSession(object):
finally:
self._server_session.reset_transaction()
self._current_transaction_opts = None
self._current_txn_address = None
self._current_txn_read_pref = None
def _advance_cluster_time(self, cluster_time):
"""Internal cluster time helper."""
@ -295,6 +298,10 @@ class ClientSession(object):
"""True if this session has an active multi-statement transaction."""
return self._current_transaction_opts is not None
def _pin_server_address(self, address):
assert self._current_txn_address is None, "Transaction already pinned"
self._current_txn_address = address
def _apply_to(self, command, is_retryable, read_preference):
self._check_ended()

View File

@ -184,14 +184,16 @@ class Collection(common.BaseObject):
unicode_decode_error_handler='replace',
document_class=dict)
def _socket_for_reads(self):
return self.__database.client._socket_for_reads(self.read_preference)
def _socket_for_reads(self, session):
return self.__database.client._socket_for_reads(
self.read_preference, session)
def _socket_for_primary_reads(self):
return self.__database.client._socket_for_reads(ReadPreference.PRIMARY)
def _socket_for_primary_reads(self, session):
return self.__database.client._socket_for_reads(
ReadPreference.PRIMARY, session)
def _socket_for_writes(self):
return self.__database.client._socket_for_writes()
def _socket_for_writes(self, session):
return self.__database.client._socket_for_writes(session)
def _command(self, sock_info, command, slave_ok=False,
read_preference=None,
@ -252,7 +254,7 @@ class Collection(common.BaseObject):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session) as sock_info:
self._command(
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
write_concern=self.write_concern,
@ -579,7 +581,7 @@ class Collection(common.BaseObject):
True, _insert_command, session)
_check_write_command_response(result)
else:
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session=None) as sock_info:
# Legacy OP_INSERT.
self._legacy_write(
sock_info, 'insert', command, op_id,
@ -1493,7 +1495,7 @@ class Collection(common.BaseObject):
('numCursors', num_cursors)])
cmd.update(kwargs)
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads(session) as (sock_info, slave_ok):
result = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
session=session)
@ -1509,7 +1511,7 @@ class Collection(common.BaseObject):
def _count(self, cmd, collation=None, session=None):
"""Internal count helper."""
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads(session) as (sock_info, slave_ok):
res = self._command(
sock_info, cmd, slave_ok,
allowable_errors=["ns missing"],
@ -1606,7 +1608,7 @@ class Collection(common.BaseObject):
"""
common.validate_list('indexes', indexes)
names = []
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session) as sock_info:
supports_collations = sock_info.max_wire_version >= 5
def gen_indexes():
for index in indexes:
@ -1647,7 +1649,7 @@ class Collection(common.BaseObject):
index_options.pop('collation', None))
index.update(index_options)
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session) as sock_info:
if collation is not None:
if sock_info.max_wire_version < 5:
raise ConfigurationError(
@ -1874,7 +1876,7 @@ class Collection(common.BaseObject):
self.__database.name, self.__name, name)
cmd = SON([("dropIndexes", self.__name), ("index", name)])
cmd.update(kwargs)
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session) as sock_info:
self._command(sock_info,
cmd,
read_preference=ReadPreference.PRIMARY,
@ -1911,7 +1913,7 @@ class Collection(common.BaseObject):
"""
cmd = SON([("reIndex", self.__name)])
cmd.update(kwargs)
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session) as sock_info:
return self._command(
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
parse_write_concern_error=True, session=session)
@ -1940,7 +1942,7 @@ class Collection(common.BaseObject):
codec_options = CodecOptions(SON)
coll = self.with_options(codec_options=codec_options,
read_preference=ReadPreference.PRIMARY)
with self._socket_for_primary_reads() as (sock_info, slave_ok):
with self._socket_for_primary_reads(session) as (sock_info, slave_ok):
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
if sock_info.max_wire_version > 2:
with self.__database.client._tmp_session(session, False) as s:
@ -2061,7 +2063,7 @@ class Collection(common.BaseObject):
"batchSize", kwargs.pop("batchSize", None))
# If the server does not support the "cursor" option we
# ignore useCursor and batchSize.
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads(session) as (sock_info, slave_ok):
dollar_out = pipeline and '$out' in pipeline[-1]
if use_cursor:
if "cursor" not in kwargs:
@ -2350,7 +2352,7 @@ class Collection(common.BaseObject):
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok,
collation=collation)["retval"]
@ -2396,7 +2398,7 @@ class Collection(common.BaseObject):
new_name = "%s.%s" % (self.__database.name, new_name)
cmd = SON([("renameCollection", self.__full_name), ("to", new_name)])
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session) as sock_info:
with self.__database.client._tmp_session(session) as s:
if sock_info.max_wire_version >= 5 and self.write_concern:
cmd['writeConcern'] = self.write_concern.document
@ -2451,7 +2453,7 @@ class Collection(common.BaseObject):
kwargs["query"] = filter
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads(session) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation, session=session)["values"]
@ -2523,7 +2525,7 @@ class Collection(common.BaseObject):
cmd.update(kwargs)
inline = 'inline' in cmd['out']
with self._socket_for_primary_reads() as (sock_info, slave_ok):
with self._socket_for_primary_reads(session) as (sock_info, slave_ok):
if (sock_info.max_wire_version >= 5 and self.write_concern and
not inline):
cmd['writeConcern'] = self.write_concern.document
@ -2592,7 +2594,7 @@ class Collection(common.BaseObject):
("out", {"inline": 1})])
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads() as (sock_info, slave_ok):
with self._socket_for_reads(session) as (sock_info, slave_ok):
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
res = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,

View File

@ -526,7 +526,8 @@ class Database(common.BaseObject):
.. mongodoc:: commands
"""
client = self.__client
with client._socket_for_reads(read_preference) as (sock_info, slave_ok):
with client._socket_for_reads(
read_preference, session) as (sock_info, slave_ok):
return self._command(sock_info, command, slave_ok, value,
check, allowable_errors, read_preference,
codec_options, session=session, **kwargs)
@ -584,7 +585,7 @@ class Database(common.BaseObject):
.. versionadded:: 3.6
"""
with self.__client._socket_for_reads(
ReadPreference.PRIMARY) as (sock_info, slave_okay):
ReadPreference.PRIMARY, session) as (sock_info, slave_okay):
return self._list_collections(
sock_info, slave_okay, session=session, **kwargs)
@ -649,7 +650,7 @@ class Database(common.BaseObject):
self.__client._purge_index(self.__name, name)
with self.__client._socket_for_reads(
ReadPreference.PRIMARY) as (sock_info, slave_ok):
ReadPreference.PRIMARY, session) as (sock_info, slave_ok):
return self._command(
sock_info, 'drop', slave_ok, _unicode(name),
allowable_errors=['ns not found'],
@ -730,7 +731,7 @@ class Database(common.BaseObject):
Added ``session`` parameter.
"""
cmd = SON([("currentOp", 1), ("$all", include_all)])
with self.__client._socket_for_writes() as sock_info:
with self.__client._socket_for_writes(session) as sock_info:
if sock_info.max_wire_version >= 4:
with self.__client._tmp_session(session) as s:
return sock_info.command("admin", cmd, session=s,

View File

@ -871,7 +871,8 @@ class MongoClient(common.BaseObject):
# Use SocketInfo.command directly to avoid implicitly creating
# another session.
with self._socket_for_reads(
ReadPreference.PRIMARY_PREFERRED) as (sock_info, slave_ok):
ReadPreference.PRIMARY_PREFERRED,
None) as (sock_info, slave_ok):
if not sock_info.supports_sessions:
return
@ -967,13 +968,31 @@ class MongoClient(common.BaseObject):
self.__reset_server(server.description.address)
raise
def _socket_for_writes(self):
server = self._get_topology().select_server(writable_server_selector)
return self._get_socket(server)
def _select_server(self, read_preference, session):
topology = self._get_topology()
if session and session.in_transaction:
if session._current_txn_address:
server = topology.select_server_by_address(
session._current_txn_address)
if not server:
raise AutoReconnect(
'Pinned server %s:%d for transaction no longer'
'available' % session._current_txn_address)
return server
server = topology.select_server(read_preference)
session._pin_server_address(server.description.address)
return server
else:
return topology.select_server(read_preference)
def _socket_for_writes(self, session):
return self._get_socket(self._select_server(
ReadPreference.PRIMARY, session))
@contextlib.contextmanager
def _socket_for_reads(self, read_preference):
preference = read_preference or ReadPreference.PRIMARY
def _socket_for_reads(self, read_preference, session):
assert read_preference is not None, "read_preference must not be None"
# Get a socket for a server matching the read preference, and yield
# sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to
# mongods with topology type Single. If the server type is Mongos,
@ -982,10 +1001,11 @@ class MongoClient(common.BaseObject):
# Thread safe: if the type is single it cannot change.
topology = self._get_topology()
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
server = topology.select_server(read_preference)
server = self._select_server(read_preference, session)
with self._get_socket(server) as sock_info:
slave_ok = (single and not sock_info.is_mongos) or (
preference != ReadPreference.PRIMARY)
read_preference != ReadPreference.PRIMARY)
yield sock_info, slave_ok
def _send_message_with_response(self, operation, read_preference=None,
@ -1005,14 +1025,14 @@ class MongoClient(common.BaseObject):
self._kill_cursors_executor.open()
topology = self._get_topology()
session = operation.session
if address:
server = topology.select_server_by_address(address)
if not server:
raise AutoReconnect('server %s:%d no longer available'
% address)
else:
selector = read_preference or writable_server_selector
server = topology.select_server(selector)
server = self._select_server(read_preference, session)
# A _Query's slaveOk bit is already set for queries with non-primary
# read preference. If this is a direct connection to a mongod, override
@ -1064,8 +1084,7 @@ class MongoClient(common.BaseObject):
return bulk.retrying if bulk else retrying
while True:
try:
server = self._get_topology().select_server(
writable_server_selector)
server = self._select_server(ReadPreference.PRIMARY, session)
supports_session = (
session is not None and
server.description.retryable_writes_supported)
@ -1539,7 +1558,7 @@ class MongoClient(common.BaseObject):
self._purge_index(name)
with self._socket_for_reads(
ReadPreference.PRIMARY) as (sock_info, slave_ok):
ReadPreference.PRIMARY, None) as (sock_info, slave_ok):
self[name]._command(
sock_info,
"dropDatabase",
@ -1681,7 +1700,7 @@ class MongoClient(common.BaseObject):
Added ``session`` parameter.
"""
cmd = SON([("fsyncUnlock", 1)])
with self._socket_for_writes() as sock_info:
with self._socket_for_writes(session=None) as sock_info:
if sock_info.max_wire_version >= 4:
try:
with self._tmp_session(session) as s:

View File

@ -315,8 +315,9 @@ class ReadPrefTester(MongoClient):
super(ReadPrefTester, self).__init__(*args, **client_options)
@contextlib.contextmanager
def _socket_for_reads(self, read_preference):
context = super(ReadPrefTester, self)._socket_for_reads(read_preference)
def _socket_for_reads(self, read_preference, session):
context = super(ReadPrefTester, self)._socket_for_reads(
read_preference, session)
with context as (sock_info, slave_ok):
self.record_a_read(sock_info.address)
yield sock_info, slave_ok