diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 78bc0a8c9..245afdef5 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -267,7 +267,7 @@ class _TransactionContext(object): return self def __exit__(self, exc_type, exc_val, exc_tb): - if self.__session._in_transaction: + if self.__session.in_transaction: if exc_val is None: self.__session.commit_transaction() else: @@ -356,7 +356,7 @@ class ClientSession(object): def _end_session(self, lock): if self._server_session is not None: try: - if self._in_transaction: + if self.in_transaction: self.abort_transaction() finally: self._client._return_server_session(self._server_session, lock) @@ -505,7 +505,7 @@ class ClientSession(object): try: ret = callback(self) except Exception as exc: - if self._in_transaction: + if self.in_transaction: self.abort_transaction() if (isinstance(exc, PyMongoError) and exc.has_error_label("TransientTransactionError") and @@ -514,8 +514,7 @@ class ClientSession(object): continue raise - if self._transaction.state in ( - _TxnState.NONE, _TxnState.COMMITTED, _TxnState.ABORTED): + if not self.in_transaction: # Assume callback intentionally ended the transaction. return ret @@ -551,7 +550,7 @@ class ClientSession(object): """ self._check_ended() - if self._in_transaction: + if self.in_transaction: raise InvalidOperation("Transaction already in progress") read_concern = self._inherit_option("read_concern", read_concern) @@ -589,7 +588,7 @@ class ClientSession(object): "Cannot call commitTransaction after calling abortTransaction") elif state is _TxnState.COMMITTED: # We're explicitly retrying the commit, move the state back to - # "in progress" so that _in_transaction returns true. + # "in progress" so that in_transaction returns true. self._transaction.state = _TxnState.IN_PROGRESS retry = True @@ -750,7 +749,7 @@ class ClientSession(object): """Process a response to a command that was run with this session.""" self._advance_cluster_time(reply.get('$clusterTime')) self._advance_operation_time(reply.get('operationTime')) - if self._in_transaction and self._transaction.sharded: + if self.in_transaction and self._transaction.sharded: recovery_token = reply.get('recoveryToken') if recovery_token: self._transaction.recovery_token = recovery_token @@ -761,8 +760,11 @@ class ClientSession(object): return self._server_session is None @property - def _in_transaction(self): - """True if this session has an active multi-statement transaction.""" + def in_transaction(self): + """True if this session has an active multi-statement transaction. + + .. versionadded:: 3.10 + """ return self._transaction.active() @property @@ -783,7 +785,7 @@ class ClientSession(object): def _txn_read_preference(self): """Return read preference of this transaction or None.""" - if self._in_transaction: + if self.in_transaction: return self._transaction.opts.read_preference return None @@ -793,14 +795,14 @@ class ClientSession(object): self._server_session.last_use = monotonic.time() command['lsid'] = self._server_session.session_id - if not self._in_transaction: + if not self.in_transaction: self._transaction.reset() if is_retryable: command['txnNumber'] = self._server_session.transaction_id return - if self._in_transaction: + if self.in_transaction: if read_preference != ReadPreference.PRIMARY: raise InvalidOperation( 'read preference in a transaction must be primary, not: ' diff --git a/pymongo/common.py b/pymongo/common.py index 76c14e08f..f208e83d0 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -824,7 +824,7 @@ class BaseObject(object): """Read only access to the write concern of this instance or session. """ # Override this operation's write concern with the transaction's. - if session and session._in_transaction: + if session and session.in_transaction: return DEFAULT_WRITE_CONCERN return self.write_concern diff --git a/pymongo/message.py b/pymongo/message.py index 56f863695..2b8ff1042 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -205,7 +205,7 @@ def _gen_find_command(coll, spec, projection, skip, limit, batch_size, options, cmd['singleBatch'] = True if batch_size: cmd['batchSize'] = batch_size - if read_concern.level and not (session and session._in_transaction): + if read_concern.level and not (session and session.in_transaction): cmd['readConcern'] = read_concern.document if collation: cmd['collation'] = collation @@ -304,7 +304,7 @@ class _Query(object): # Explain does not support readConcern. if (not explain and session.options.causal_consistency and session.operation_time is not None - and not session._in_transaction): + and not session.in_transaction): cmd.setdefault( 'readConcern', {})[ 'afterClusterTime'] = session.operation_time diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 79ceea245..8220a5e44 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1255,7 +1255,7 @@ class MongoClient(common.BaseObject): # Pin this session to the selected server if it's performing a # sharded transaction. if server.description.mongos and (session and - session._in_transaction): + session.in_transaction): session._pin_mongos(server) return server except PyMongoError as exc: @@ -1355,7 +1355,7 @@ class MongoClient(common.BaseObject): Re-raises any exception thrown by func(). """ retryable = (retryable and self.retry_writes - and session and not session._in_transaction) + and session and not session.in_transaction) last_error = None retrying = False @@ -1445,7 +1445,7 @@ class MongoClient(common.BaseObject): """ retryable = (retryable and self.retry_reads - and not (session and session._in_transaction)) + and not (session and session.in_transaction)) last_error = None retrying = False diff --git a/pymongo/network.py b/pymongo/network.py index cf88a814a..0996180f5 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -98,7 +98,7 @@ def command(sock, dbname, spec, slave_ok, is_mongos, orig = spec if is_mongos and not use_op_msg: spec = message._maybe_add_read_preference(spec, read_preference) - if read_concern and not (session and session._in_transaction): + if read_concern and not (session and session.in_transaction): if read_concern.level: spec['readConcern'] = read_concern.document if (session and session.options.causal_consistency diff --git a/test/test_transactions.py b/test/test_transactions.py index 5b575adbf..88e6dae5a 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -334,6 +334,42 @@ class TestTransactionsConvenientAPI(TransactionsBase): self.assertEqual(listener.started_command_names(), ['insert', 'commitTransaction', 'commitTransaction']) + # Tested here because this supports Motor's convenient transactions API. + @client_context.require_transactions + def test_in_transaction_property(self): + client = client_context.client + coll = client.test.testcollection + coll.insert_one({}) + self.addCleanup(coll.drop) + + with client.start_session() as s: + self.assertFalse(s.in_transaction) + s.start_transaction() + self.assertTrue(s.in_transaction) + coll.insert_one({}, session=s) + self.assertTrue(s.in_transaction) + s.commit_transaction() + self.assertFalse(s.in_transaction) + + with client.start_session() as s: + s.start_transaction() + # commit empty transaction + s.commit_transaction() + self.assertFalse(s.in_transaction) + + with client.start_session() as s: + s.start_transaction() + s.abort_transaction() + self.assertFalse(s.in_transaction) + + # Using a callback + def callback(session): + self.assertTrue(session.in_transaction) + with client.start_session() as s: + self.assertFalse(s.in_transaction) + s.with_transaction(callback) + self.assertFalse(s.in_transaction) + def create_test(scenario_def, test, name): @client_context.require_test_commands