diff --git a/src/mongo/db/transaction/transaction_api.cpp b/src/mongo/db/transaction/transaction_api.cpp index 6c3ed0f9c13..f772f437bcf 100644 --- a/src/mongo/db/transaction/transaction_api.cpp +++ b/src/mongo/db/transaction/transaction_api.cpp @@ -55,6 +55,7 @@ #include "mongo/db/write_concern_options.h" #include "mongo/executor/inline_executor.h" #include "mongo/executor/task_executor.h" +#include "mongo/idl/idl_parser.h" #include "mongo/logv2/log.h" #include "mongo/rpc/factory.h" #include "mongo/rpc/get_status_from_command_result.h" @@ -521,6 +522,43 @@ BatchedCommandResponse SEPTransactionClient::runCRUDOpSync(const BatchedCommandR return std::move(result).get(); } +ExecutorFuture SEPTransactionClient::_runCRUDOp( + const BulkWriteCommandRequest& cmd) const { + BSONObjBuilder cmdBob(cmd.toBSON(BSONObj())); + // BulkWrite can only execute on admin DB. + return runCommand(DatabaseName::kAdmin, cmdBob.obj()) + .thenRunOn(_executor) + .then([](BSONObj reply) { + uassertStatusOK(getStatusFromWriteCommandReply(reply)); + + IDLParserContext ctx("BulkWriteCommandReplyParse"); + auto response = BulkWriteCommandReply::parse(ctx, reply); + return response; + }); +} + +SemiFuture SEPTransactionClient::runCRUDOp( + const BulkWriteCommandRequest& cmd) const { + return _runCRUDOp(cmd).semi(); +} + +BulkWriteCommandReply SEPTransactionClient::runCRUDOpSync( + const BulkWriteCommandRequest& cmd) const { + + Notification mayReturn; + + auto result = + _runCRUDOp(cmd) + .unsafeToInlineFuture() + // Use tap and tapError instead of tapAll since tapAll is not move-only type friendly + .tap([&](auto&&) { mayReturn.set(); }) + .tapError([&](auto&&) { mayReturn.set(); }); + + runFutureInline(_inlineExecutor.get(), mayReturn); + + return std::move(result).get(); +} + ExecutorFuture> SEPTransactionClient::_exhaustiveFind( const FindCommandRequest& cmd) const { return runCommand(cmd.getDbName(), cmd.toBSON({})) @@ -816,7 +854,9 @@ void Transaction::prepareRequest(BSONObjBuilder* cmdBuilder) { !isRetryableWriteCommand( cmdBuilder->asTempObj().firstElement().fieldNameStringData()) || (cmdBuilder->hasField(write_ops::WriteCommandRequestBase::kStmtIdsFieldName) || - cmdBuilder->hasField(write_ops::WriteCommandRequestBase::kStmtIdFieldName)), + cmdBuilder->hasField(write_ops::WriteCommandRequestBase::kStmtIdFieldName)) || + (cmdBuilder->hasField(BulkWriteCommandRequest::kStmtIdFieldName) || + cmdBuilder->hasField(BulkWriteCommandRequest::kStmtIdsFieldName)), str::stream() << "In a retryable write transaction every retryable write command should have an " "explicit statement id, command: " diff --git a/src/mongo/db/transaction/transaction_api.h b/src/mongo/db/transaction/transaction_api.h index ff5cb5aa335..d754ab7390f 100644 --- a/src/mongo/db/transaction/transaction_api.h +++ b/src/mongo/db/transaction/transaction_api.h @@ -30,6 +30,7 @@ #pragma once #include "mongo/db/cancelable_operation_context.h" +#include "mongo/db/commands/bulk_write_gen.h" #include "mongo/db/operation_context.h" #include "mongo/db/query/find_command_gen.h" #include "mongo/db/resource_yielder.h" @@ -128,6 +129,13 @@ public: virtual BatchedCommandResponse runCRUDOpSync(const BatchedCommandRequest& cmd, std::vector stmtIds) const = 0; + /** + * Helper method to run BulkWriteCommandRequest in the transaction client's transaction. + */ + virtual SemiFuture runCRUDOp( + const BulkWriteCommandRequest& cmd) const = 0; + virtual BulkWriteCommandReply runCRUDOpSync(const BulkWriteCommandRequest& cmd) const = 0; + /** * Helper method that runs the given find in the transaction client's transaction and will * iterate and exhaust the find's cursor, returning a vector with all matching documents. @@ -291,6 +299,10 @@ public: virtual BatchedCommandResponse runCRUDOpSync(const BatchedCommandRequest& cmd, std::vector stmtIds) const override; + virtual SemiFuture runCRUDOp( + const BulkWriteCommandRequest& cmd) const override; + virtual BulkWriteCommandReply runCRUDOpSync(const BulkWriteCommandRequest& cmd) const override; + virtual SemiFuture> exhaustiveFind( const FindCommandRequest& cmd) const override; virtual std::vector exhaustiveFindSync(const FindCommandRequest& cmd) const override; @@ -309,6 +321,8 @@ private: ExecutorFuture _runCRUDOp(const BatchedCommandRequest& cmd, std::vector stmtIds) const; + ExecutorFuture _runCRUDOp(const BulkWriteCommandRequest& cmd) const; + ExecutorFuture> _exhaustiveFind(const FindCommandRequest& cmd) const; private: diff --git a/src/mongo/db/transaction/transaction_api_test.cpp b/src/mongo/db/transaction/transaction_api_test.cpp index d1bd94b067d..b14b6918644 100644 --- a/src/mongo/db/transaction/transaction_api_test.cpp +++ b/src/mongo/db/transaction/transaction_api_test.cpp @@ -204,6 +204,15 @@ public: MONGO_UNREACHABLE; } + virtual SemiFuture runCRUDOp( + const BulkWriteCommandRequest& cmd) const override { + MONGO_UNREACHABLE; + } + + virtual BulkWriteCommandReply runCRUDOpSync(const BulkWriteCommandRequest& cmd) const override { + MONGO_UNREACHABLE; + } + virtual bool supportsClientTransactionContext() const override { return true; } @@ -505,6 +514,15 @@ public: MONGO_UNREACHABLE; } + virtual SemiFuture runCRUDOp( + const BulkWriteCommandRequest& cmd) const override { + MONGO_UNREACHABLE; + } + + virtual BulkWriteCommandReply runCRUDOpSync(const BulkWriteCommandRequest& cmd) const override { + MONGO_UNREACHABLE; + } + virtual SemiFuture> exhaustiveFind( const FindCommandRequest& cmd) const override { MONGO_UNREACHABLE; diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript index e04d07aea0d..d1e67e670b5 100644 --- a/src/mongo/s/SConscript +++ b/src/mongo/s/SConscript @@ -698,6 +698,7 @@ env.CppUnitTest( 'write_ops/batch_write_op_test.cpp', 'write_ops/batched_command_request_test.cpp', 'write_ops/batched_command_response_test.cpp', + 'write_ops/bulk_write_command_modifier_test.cpp', 'write_ops/bulk_write_exec_test.cpp', 'write_ops/write_op_test.cpp', 'write_ops/write_without_shard_key_util_test.cpp', diff --git a/src/mongo/s/write_ops/SConscript b/src/mongo/s/write_ops/SConscript index 7d4e4536841..30329e3e910 100644 --- a/src/mongo/s/write_ops/SConscript +++ b/src/mongo/s/write_ops/SConscript @@ -7,9 +7,8 @@ env = env.Clone() env.Library( target='batch_write_types', source=[ - 'batched_command_request.cpp', - 'batched_command_response.cpp', - 'batched_upsert_detail.cpp', + 'batched_command_request.cpp', 'batched_command_response.cpp', 'batched_upsert_detail.cpp', + 'bulk_write_command_modifier.cpp' ], LIBDEPS=[ '$BUILD_DIR/mongo/base', diff --git a/src/mongo/s/write_ops/bulk_write_command_modifier.cpp b/src/mongo/s/write_ops/bulk_write_command_modifier.cpp new file mode 100644 index 00000000000..fe6412726d4 --- /dev/null +++ b/src/mongo/s/write_ops/bulk_write_command_modifier.cpp @@ -0,0 +1,242 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + + +#include "mongo/db/commands/bulk_write_gen.h" +#include "mongo/db/ops/write_ops_parsers.h" +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/variables.h" +#include "mongo/s/write_ops/bulk_write_command_modifier.h" + +#include "mongo/bson/bsonobj.h" + +namespace mongo { + +void BulkWriteCommandModifier::parseRequestFromOpMsg(const NamespaceString& nss, + const OpMsgRequest& request) { + auto shardVersionField = request.body[ShardVersion::kShardVersionField]; + if (!shardVersionField.eoo()) { + auto shardVersion = ShardVersion::parse(shardVersionField); + if (shardVersion == ShardVersion::UNSHARDED()) { + setDbVersion(nss, DatabaseVersion(request.body)); + } + setShardVersion(nss, shardVersion); + } + + // The 'isTimeseriesNamespace' is an internal parameter used for communication between mongos + // and mongod. + auto isTimeseriesNamespace = + request.body[write_ops::WriteCommandRequestBase::kIsTimeseriesNamespaceFieldName]; + uassert(7299100, + "the 'isTimeseriesNamespace' parameter cannot be used on mongos", + !isTimeseriesNamespace.trueValue()); + + setIsTimeseriesNamespace(nss, isTimeseriesNamespace.trueValue()); +} + +std::tuple BulkWriteCommandModifier::getNsInfoEntry( + const NamespaceString& nss) { + if (_nsInfoIdxes.contains(nss)) { + // Already have this NamespaceInfoEntry stored. + auto idx = _nsInfoIdxes[nss]; + return std::tie(_nsInfos[idx], idx); + } + // Create new NamespaceInfoEntry. + auto nsInfoEntry = NamespaceInfoEntry(nss); + auto idx = _nsInfos.size(); + _nsInfos.emplace_back(nsInfoEntry); + + _nsInfoIdxes[nss] = idx; + return std::tie(_nsInfos[idx], idx); +} + +void BulkWriteCommandModifier::finishBuild() { + _request->setOps(std::move(_ops)); + _request->setNsInfo(std::move(_nsInfos)); +} + +void BulkWriteCommandModifier::addOp(write_ops::InsertCommandRequest insertOp) { + auto nss = insertOp.getNamespace(); + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setEncryptionInformation(insertOp.getEncryptionInformation()); + + for (const auto& doc : insertOp.getDocuments()) { + auto op = BulkWriteInsertOp(idx, doc); + _ops.emplace_back(op); + } +} + +void BulkWriteCommandModifier::addOp(write_ops::UpdateCommandRequest updateOp) { + auto nss = updateOp.getNamespace(); + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setEncryptionInformation(updateOp.getEncryptionInformation()); + + for (const auto& update : updateOp.getUpdates()) { + auto op = BulkWriteUpdateOp(idx, update.getQ(), update.getU()); + + op.setArrayFilters(update.getArrayFilters()); + op.setMulti(update.getMulti()); + op.setCollation(update.getCollation()); + op.setUpsert(update.getUpsert()); + op.setHint(update.getHint()); + + // TODO SERVER-73231 let will be renamed to constants. + op.setLet(update.getC()); + + // TODO SET OTHER OPTIONAL FIELDS HERE!!!!!!!! + + _ops.emplace_back(op); + } +} + +void BulkWriteCommandModifier::addOp(write_ops::DeleteCommandRequest deleteOp) { + auto nss = deleteOp.getNamespace(); + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setEncryptionInformation(deleteOp.getEncryptionInformation()); + + for (const auto& delOp : deleteOp.getDeletes()) { + auto op = BulkWriteDeleteOp(idx, delOp.getQ()); + + op.setHint(delOp.getHint()); + op.setMulti(delOp.getMulti()); + op.setCollation(delOp.getCollation()); + + _ops.emplace_back(op); + } +} + +void BulkWriteCommandModifier::addInsert(const OpMsgRequest& request) { + auto parsedInsertOp = InsertOp::parse(request); + + auto nss = parsedInsertOp.getNamespace(); + + parseRequestFromOpMsg(nss, request); + + addOp(parsedInsertOp); +} + +void BulkWriteCommandModifier::addUpdate(const OpMsgRequest& request) { + auto parsedUpdateOp = UpdateOp::parse(request); + + auto nss = parsedUpdateOp.getNamespace(); + + parseRequestFromOpMsg(nss, request); + + addOp(parsedUpdateOp); +} + +void BulkWriteCommandModifier::addDelete(const OpMsgRequest& request) { + auto parsedDeleteOp = DeleteOp::parse(request); + + auto nss = parsedDeleteOp.getNamespace(); + + parseRequestFromOpMsg(nss, request); + + addOp(parsedDeleteOp); +} + +void BulkWriteCommandModifier::addInsertOps(const NamespaceString& nss, + const std::vector docs) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + + for (const auto& doc : docs) { + auto op = BulkWriteInsertOp(idx, doc); + + _ops.emplace_back(op); + } +} + +void BulkWriteCommandModifier::addUpdateOp( + const NamespaceString& nss, + const BSONObj& query, + const BSONObj& update, + bool upsert, + bool multi, + const StringData& returnField, + const boost::optional>& arrayFilters, + const boost::optional& collation, + const boost::optional& sort, + const boost::optional& returnFields, + const boost::optional& hint) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + + auto op = BulkWriteUpdateOp(idx, query, update); + + op.setUpsert(upsert); + op.setMulti(multi); + op.setReturn(returnField); + op.setReturnFields(returnFields); + op.setCollation(collation); + op.setHint(hint.value_or(BSONObj())); + op.setArrayFilters(arrayFilters); + op.setSort(sort); + + _ops.emplace_back(op); +} + +void BulkWriteCommandModifier::addPipelineUpdateOps(const NamespaceString& nss, + const BSONObj& query, + const std::vector& updates, + bool upsert, + bool useMultiUpdate) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + + auto updateMod = write_ops::UpdateModification(); + auto op = BulkWriteUpdateOp(idx, query, updates); + + op.setUpsert(upsert); + op.setMulti(useMultiUpdate); + + _ops.emplace_back(op); +} + +void BulkWriteCommandModifier::addDeleteOp(const NamespaceString& nss, + const BSONObj& query, + bool multiDelete, + bool returnField, + const boost::optional& collation, + const boost::optional& sort, + const boost::optional& returnFields, + const boost::optional& hint) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + + auto op = BulkWriteDeleteOp(idx, query); + + op.setMulti(multiDelete); + op.setReturn(returnField); + op.setReturnFields(returnFields); + op.setHint(hint.value_or(BSONObj())); + op.setSort(sort); + op.setCollation(collation); + + _ops.emplace_back(op); +} + +} // namespace mongo diff --git a/src/mongo/s/write_ops/bulk_write_command_modifier.h b/src/mongo/s/write_ops/bulk_write_command_modifier.h new file mode 100644 index 00000000000..00b1e3fdbc2 --- /dev/null +++ b/src/mongo/s/write_ops/bulk_write_command_modifier.h @@ -0,0 +1,168 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include +#include + +#include "mongo/crypto/fle_field_schema_gen.h" +#include "mongo/db/commands/bulk_write_crud_op.h" +#include "mongo/db/commands/bulk_write_gen.h" +#include "mongo/db/namespace_string.h" +#include "mongo/db/ops/write_ops.h" +#include "mongo/rpc/op_msg.h" +#include "mongo/s/database_version.h" +#include "mongo/s/shard_version.h" +#include "mongo/stdx/unordered_map.h" + +namespace mongo { + +/** + * Helper functions which add new operations into an existing BulkWriteCommandRequest. + */ +class BulkWriteCommandModifier { +public: + BulkWriteCommandModifier(BulkWriteCommandRequest* request, size_t capacity = 0) + : _request(request), _ops(request->getOps()), _nsInfos(request->getNsInfo()) { + invariant(_request); + for (size_t i = 0; i < _nsInfos.size(); i++) { + auto nsInfo = _nsInfos[i]; + _nsInfoIdxes[nsInfo.getNs()] = i; + } + + if (capacity > 0) { + _ops.reserve(capacity); + } + } + + BulkWriteCommandModifier(BulkWriteCommandModifier&&) = default; + + /** + * This function must be called for the BulkWriteCommandRequest to be in a usable state. + */ + void finishBuild(); + + void addOp(write_ops::InsertCommandRequest insertOp); + void addOp(write_ops::UpdateCommandRequest updateOp); + void addOp(write_ops::DeleteCommandRequest deleteOp); + + void addInsert(const OpMsgRequest& request); + void addUpdate(const OpMsgRequest& request); + void addDelete(const OpMsgRequest& request); + + size_t numOps() const { + return _request->getOps().size(); + } + + void setIsTimeseriesNamespace(const NamespaceString& nss, bool isTimeseriesNamespace) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setIsTimeseriesNamespace(isTimeseriesNamespace); + } + + void setEncryptionInformation(const NamespaceString& nss, + const EncryptionInformation& encryption) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setEncryptionInformation(encryption); + } + + void setShardVersion(const NamespaceString& nss, const ShardVersion& sv) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setShardVersion(sv); + } + + const ShardVersion& getShardVersion(const NamespaceString& nss) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + invariant(nsInfoEntry.getShardVersion()); + return *nsInfoEntry.getShardVersion(); + } + + void setDbVersion(const NamespaceString& nss, const DatabaseVersion& dbv) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + nsInfoEntry.setDatabaseVersion(dbv); + } + + const DatabaseVersion& getDbVersion(const NamespaceString& nss) { + auto [nsInfoEntry, idx] = getNsInfoEntry(nss); + invariant(nsInfoEntry.getDatabaseVersion()); + return *nsInfoEntry.getDatabaseVersion(); + } + + // TODO SERVER-73231: Add setter for top level let. + + void addInsertOps(const NamespaceString& nss, std::vector docs); + + void addUpdateOp(const NamespaceString& nss, + const BSONObj& query, + const BSONObj& update, + bool upsert, + bool multi, + const StringData& returnField, + const boost::optional>& arrayFilters, + const boost::optional& collation, + const boost::optional& sort, + const boost::optional& returnFields, + const boost::optional& hint); + + void addPipelineUpdateOps(const NamespaceString& nss, + const BSONObj& query, + const std::vector& updates, + bool upsert, + bool useMultiUpdate); + + void addDeleteOp(const NamespaceString& nss, + const BSONObj& query, + bool multiDelete, + bool returnField, + const boost::optional& collation, + const boost::optional& sort, + const boost::optional& returnFields, + const boost::optional& hint); + +private: + BulkWriteCommandRequest* _request; + + stdx::unordered_map _nsInfoIdxes; + + std::vector< + stdx::variant> + _ops; + std::vector _nsInfos; + + /** + * Gets the NamespaceInfoEntry for the associated namespace. If one does not exist + * then it will be created. Returns a reference to the NamespaceInfoEntry and the index in + * the nsInfo array. + */ + std::tuple getNsInfoEntry(const NamespaceString& nss); + + void parseRequestFromOpMsg(const NamespaceString& nss, const OpMsgRequest& request); +}; + +} // namespace mongo diff --git a/src/mongo/s/write_ops/bulk_write_command_modifier_test.cpp b/src/mongo/s/write_ops/bulk_write_command_modifier_test.cpp new file mode 100644 index 00000000000..9fbd6ca3f86 --- /dev/null +++ b/src/mongo/s/write_ops/bulk_write_command_modifier_test.cpp @@ -0,0 +1,494 @@ +/** + * Copyright (C) 2023-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/bson/json.h" +#include "mongo/db/commands/bulk_write_gen.h" +#include "mongo/db/namespace_string.h" +#include "mongo/db/ops/write_ops_gen.h" +#include "mongo/db/ops/write_ops_parsers_test_helpers.h" +#include "mongo/s/shard_version_factory.h" +#include "mongo/s/write_ops/bulk_write_command_modifier.h" +#include "mongo/unittest/assert.h" +#include "mongo/unittest/bson_test_util.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +TEST(BulkWriteCommandModifier, AddInsert) { + BSONArray insertArray = BSON_ARRAY(BSON("a" << 1) << BSON("b" << 1)); + + BSONObj origInsertRequestObj = BSON("insert" + << "test" + << "documents" << insertArray << "writeConcern" + << BSON("w" << 1) << "ordered" << true); + + for (auto docSeq : {false, true}) { + const auto opMsgRequest(toOpMsg("TestDB", origInsertRequestObj, docSeq)); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addInsert(opMsgRequest); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(2, request.getOps().size()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + } +} + +TEST(BulkWriteCommandModifier, AddOpInsert) { + auto nss = NamespaceString("TestDB", "test"); + auto req = write_ops::InsertCommandRequest(nss); + auto docs = std::vector(); + docs.emplace_back(BSON("a" << 1)); + docs.emplace_back(BSON("b" << 1)); + req.setDocuments(docs); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addOp(req); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(2, request.getOps().size()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); +} + +TEST(BulkWriteCommandModifier, AddInsertOps) { + auto nss = NamespaceString("TestDB", "test"); + auto docs = std::vector(); + docs.emplace_back(BSON("a" << 1)); + docs.emplace_back(BSON("b" << 1)); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addInsertOps(nss, docs); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(2, request.getOps().size()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); +} + +TEST(BulkWriteCommandModifier, InsertWithShardVersion) { + BSONArray insertArray = BSON_ARRAY(BSON("a" << 1) << BSON("b" << 1)); + + const OID epoch = OID::gen(); + const Timestamp timestamp(2, 2); + const Timestamp majorAndMinor(1, 2); + + BSONObj origInsertRequestObj = BSON("insert" + << "test" + << "documents" << insertArray << "writeConcern" + << BSON("w" << 1) << "ordered" << true << "shardVersion" + << BSON("e" << epoch << "t" << timestamp << "v" + << majorAndMinor)); + + for (auto docSeq : {false, true}) { + const auto opMsgRequest(toOpMsg("TestDB", origInsertRequestObj, docSeq)); + + BulkWriteCommandRequest request = BulkWriteCommandRequest(); + BulkWriteCommandModifier builder(&request); + builder.addInsert(opMsgRequest); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_NE(boost::none, nsInfo[0].getShardVersion()); + ASSERT_EQ(ShardVersionFactory::make(ChunkVersion({epoch, timestamp}, {1, 2}), + boost::optional(boost::none)) + .toString(), + (*nsInfo[0].getShardVersion()).toString()); + } +} + +TEST(BulkWriteCommandModifier, AddUpdate) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("x" << 1); + const BSONObj update = BSON("$inc" << BSON("x" << 1)); + const BSONObj collation = BSON("locale" + << "en_US"); + const BSONObj arrayFilter = BSON("i" << 0); + for (bool upsert : {false, true}) { + for (bool multi : {false, true}) { + auto rawUpdate = + BSON("q" << query << "u" << update << "arrayFilters" << BSON_ARRAY(arrayFilter) + << "multi" << multi << "upsert" << upsert << "collation" << collation); + auto cmd = BSON("update" << nss.coll() << "updates" << BSON_ARRAY(rawUpdate)); + for (bool seq : {false, true}) { + auto opMsgRequest = toOpMsg(nss.db(), cmd, seq); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addUpdate(opMsgRequest); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(upsert, op.getUpdate()->getUpsert()); + ASSERT_EQ(multi, op.getUpdate()->getMulti()); + ASSERT_BSONOBJ_EQ(query, op.getUpdate()->getFilter()); + ASSERT_BSONOBJ_EQ(update, op.getUpdate()->getUpdateMods().getUpdateModifier()); + ASSERT_BSONOBJ_EQ(collation, op.getUpdate()->getCollation().value_or(BSONObj())); + ASSERT(op.getUpdate()->getArrayFilters()); + auto filter = (*op.getUpdate()->getArrayFilters())[0]; + ASSERT_BSONOBJ_EQ(arrayFilter, filter); + } + } + } +} + +TEST(BulkWriteCommandModifier, AddOpUpdate) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("x" << 1); + const BSONObj update = BSON("$inc" << BSON("x" << 1)); + const BSONObj collation = BSON("locale" + << "en_US"); + const BSONObj arrayFilter = BSON("i" << 0); + + auto updateOp = write_ops::UpdateOpEntry(); + updateOp.setQ(query); + updateOp.setU(update); + updateOp.setCollation(collation); + updateOp.setArrayFilters({{arrayFilter}}); + + for (bool upsert : {false, true}) { + for (bool multi : {false, true}) { + auto req = write_ops::UpdateCommandRequest(nss); + updateOp.setMulti(multi); + updateOp.setUpsert(upsert); + req.setUpdates({updateOp}); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addOp(req); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(upsert, op.getUpdate()->getUpsert()); + ASSERT_EQ(multi, op.getUpdate()->getMulti()); + ASSERT_BSONOBJ_EQ(query, op.getUpdate()->getFilter()); + ASSERT_BSONOBJ_EQ(update, op.getUpdate()->getUpdateMods().getUpdateModifier()); + ASSERT_BSONOBJ_EQ(collation, op.getUpdate()->getCollation().value_or(BSONObj())); + ASSERT(op.getUpdate()->getArrayFilters()); + auto filter = (*op.getUpdate()->getArrayFilters())[0]; + ASSERT_BSONOBJ_EQ(arrayFilter, filter); + } + } +} + +TEST(BulkWriteCommandModifier, AddUpdateOps) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("x" << 1); + const BSONObj update = BSON("$inc" << BSON("x" << 1)); + const BSONObj collation = BSON("locale" + << "en_US"); + const BSONObj arrayFilter = BSON("i" << 0); + + for (bool upsert : {false, true}) { + for (bool multi : {false, true}) { + for (std::string returnField : {"pre", "post"}) { + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addUpdateOp(nss, + query, + update, + upsert, + multi, + returnField, + {{arrayFilter}}, + collation, + boost::none, + boost::none, + boost::none); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(upsert, op.getUpdate()->getUpsert()); + ASSERT_EQ(multi, op.getUpdate()->getMulti()); + ASSERT_BSONOBJ_EQ(query, op.getUpdate()->getFilter()); + ASSERT_BSONOBJ_EQ(update, op.getUpdate()->getUpdateMods().getUpdateModifier()); + ASSERT_BSONOBJ_EQ(collation, op.getUpdate()->getCollation().value_or(BSONObj())); + ASSERT(op.getUpdate()->getArrayFilters()); + auto filter = (*op.getUpdate()->getArrayFilters())[0]; + ASSERT_BSONOBJ_EQ(arrayFilter, filter); + ASSERT_EQ(returnField, *op.getUpdate()->getReturn()); + } + } + } +} + +TEST(CommandWriteOpsParsers, BulkWriteUpdateWithPipeline) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("q" << BSON("x" << 1)); + std::vector pipeline{BSON("$addFields" << BSON("x" << 1))}; + const BSONObj update = BSON("u" << pipeline); + for (bool upsert : {false, true}) { + for (bool multi : {false, true}) { + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addPipelineUpdateOps(nss, query, pipeline, upsert, multi); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(upsert, op.getUpdate()->getUpsert()); + ASSERT_EQ(multi, op.getUpdate()->getMulti()); + ASSERT_BSONOBJ_EQ(query, op.getUpdate()->getFilter()); + ASSERT_BSONOBJ_EQ(pipeline[0], op.getUpdate()->getUpdateMods().getUpdatePipeline()[0]); + } + } +} + +TEST(BulkWriteCommandModifier, AddDelete) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("x" << 1); + const BSONObj collation = BSON("locale" + << "en_US"); + for (bool multi : {false, true}) { + auto rawDelete = + BSON("q" << query << "limit" << (multi ? 0 : 1) << "collation" << collation); + auto cmd = BSON("delete" << nss.coll() << "deletes" << BSON_ARRAY(rawDelete)); + for (bool seq : {false, true}) { + auto opMsgRequest = toOpMsg(nss.db(), cmd, seq); + + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addDelete(opMsgRequest); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(multi, op.getDelete()->getMulti()); + ASSERT_BSONOBJ_EQ(query, op.getDelete()->getFilter()); + ASSERT_BSONOBJ_EQ(collation, op.getDelete()->getCollation().value_or(BSONObj())); + } + } +} + +TEST(BulkWriteCommandModifier, AddOpDelete) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("x" << 1); + const BSONObj collation = BSON("locale" + << "en_US"); + + auto delOp = write_ops::DeleteOpEntry(); + delOp.setCollation(collation); + delOp.setQ(query); + for (bool multi : {false, true}) { + auto delReq = write_ops::DeleteCommandRequest(nss); + delOp.setMulti(multi); + delReq.setDeletes({delOp}); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addOp(delReq); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(multi, op.getDelete()->getMulti()); + ASSERT_BSONOBJ_EQ(query, op.getDelete()->getFilter()); + ASSERT_BSONOBJ_EQ(collation, op.getDelete()->getCollation().value_or(BSONObj())); + } +} + +// Add delete ops +TEST(BulkWriteCommandModifier, AddDeleteOps) { + auto nss = NamespaceString("TestDB", "test"); + const BSONObj query = BSON("x" << 1); + const BSONObj collation = BSON("locale" + << "en_US"); + for (bool multi : {false, true}) { + for (bool returnField : {false, true}) { + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addDeleteOp( + nss, query, multi, returnField, collation, boost::none, boost::none, boost::none); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(1, request.getOps().size()); + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(multi, op.getDelete()->getMulti()); + ASSERT_EQ(returnField, op.getDelete()->getReturn()); + ASSERT_BSONOBJ_EQ(query, op.getDelete()->getFilter()); + ASSERT_BSONOBJ_EQ(collation, op.getDelete()->getCollation().value_or(BSONObj())); + } + } +} + +TEST(BulkWriteCommandModifier, TestMultiOpsSameNs) { + auto nss = NamespaceString("TestDB", "test"); + auto docs = std::vector(); + docs.emplace_back(BSON("a" << 1)); + docs.emplace_back(BSON("b" << 1)); + + const BSONObj query = BSON("x" << 1); + const BSONObj collation = BSON("locale" + << "en_US"); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addInsertOps(nss, docs); + builder.addDeleteOp(nss, query, true, true, collation, boost::none, boost::none, boost::none); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(1, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(3, request.getOps().size()); + { + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(BulkWriteCRUDOp::kInsert, op.getType()); + } + { + auto op = BulkWriteCRUDOp(request.getOps()[1]); + ASSERT_EQ(BulkWriteCRUDOp::kInsert, op.getType()); + } + { + auto op = BulkWriteCRUDOp(request.getOps()[2]); + ASSERT_EQ(BulkWriteCRUDOp::kDelete, op.getType()); + ASSERT_EQ(true, op.getDelete()->getMulti()); + ASSERT_EQ(true, op.getDelete()->getReturn()); + ASSERT_BSONOBJ_EQ(query, op.getDelete()->getFilter()); + ASSERT_BSONOBJ_EQ(collation, op.getDelete()->getCollation().value_or(BSONObj())); + } +} + +// Multiple ops (different types) different namespaces +TEST(BulkWriteCommandModifier, TestMultiOpsDifferentNs) { + auto nss = NamespaceString("TestDB", "test"); + auto nss2 = NamespaceString("TestDB", "test1"); + auto docs = std::vector(); + docs.emplace_back(BSON("a" << 1)); + docs.emplace_back(BSON("b" << 1)); + + const BSONObj query = BSON("x" << 1); + const BSONObj collation = BSON("locale" + << "en_US"); + + BulkWriteCommandRequest request; + BulkWriteCommandModifier builder(&request); + builder.addInsertOps(nss, docs); + builder.addDeleteOp(nss2, query, true, true, collation, boost::none, boost::none, boost::none); + builder.finishBuild(); + + auto nsInfo = request.getNsInfo(); + ASSERT_EQ(2, nsInfo.size()); + ASSERT_EQ("TestDB", nsInfo[0].getNs().db()); + ASSERT_EQ("test", nsInfo[0].getNs().coll()); + ASSERT_EQ("TestDB", nsInfo[1].getNs().db()); + ASSERT_EQ("test1", nsInfo[1].getNs().coll()); + ASSERT_EQ(boost::none, nsInfo[0].getShardVersion()); + + ASSERT_EQ(3, request.getOps().size()); + { + auto op = BulkWriteCRUDOp(request.getOps()[0]); + ASSERT_EQ(BulkWriteCRUDOp::kInsert, op.getType()); + ASSERT_EQ(0, op.getNsInfoIdx()); + } + { + auto op = BulkWriteCRUDOp(request.getOps()[1]); + ASSERT_EQ(BulkWriteCRUDOp::kInsert, op.getType()); + ASSERT_EQ(0, op.getNsInfoIdx()); + } + { + auto op = BulkWriteCRUDOp(request.getOps()[2]); + ASSERT_EQ(BulkWriteCRUDOp::kDelete, op.getType()); + ASSERT_EQ(1, op.getNsInfoIdx()); + ASSERT_EQ(true, op.getDelete()->getMulti()); + ASSERT_EQ(true, op.getDelete()->getReturn()); + ASSERT_BSONOBJ_EQ(query, op.getDelete()->getFilter()); + ASSERT_BSONOBJ_EQ(collation, op.getDelete()->getCollation().value_or(BSONObj())); + } +} + +} // namespace +} // namespace mongo