Add script to help convert sync tests to async tests (#1825)

This commit is contained in:
Noah Stapp 2024-09-03 13:26:11 -04:00 committed by GitHub
parent a4645f0f8b
commit 3840d9dd0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 148 additions and 0 deletions

View File

@ -248,3 +248,10 @@ you are attempting to validate new spec tests in PyMongo.
## Making a Release
Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process).
## Converting a test to async
The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a
partially-converted asynchronous version of the same name to the `test/asynchronous` directory.
Use this generated file as a starting point for the completed conversion.
The script is used like so: `python tools/convert_test_to_async.py [test_file.py]`

View File

@ -0,0 +1,141 @@
from __future__ import annotations
import asyncio
import sys
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
replacements = {
"Collection": "AsyncCollection",
"Database": "AsyncDatabase",
"Cursor": "AsyncCursor",
"MongoClient": "AsyncMongoClient",
"CommandCursor": "AsyncCommandCursor",
"RawBatchCursor": "AsyncRawBatchCursor",
"RawBatchCommandCursor": "AsyncRawBatchCommandCursor",
"ClientSession": "AsyncClientSession",
"ChangeStream": "AsyncChangeStream",
"CollectionChangeStream": "AsyncCollectionChangeStream",
"DatabaseChangeStream": "AsyncDatabaseChangeStream",
"ClusterChangeStream": "AsyncClusterChangeStream",
"_Bulk": "_AsyncBulk",
"_ClientBulk": "_AsyncClientBulk",
"Connection": "AsyncConnection",
"synchronous": "asynchronous",
"Synchronous": "Asynchronous",
"next": "await anext",
"_Lock": "_ALock",
"_Condition": "_ACondition",
"GridFS": "AsyncGridFS",
"GridFSBucket": "AsyncGridFSBucket",
"GridIn": "AsyncGridIn",
"GridOut": "AsyncGridOut",
"GridOutCursor": "AsyncGridOutCursor",
"GridOutIterator": "AsyncGridOutIterator",
"GridOutChunkIterator": "_AsyncGridOutChunkIterator",
"_grid_in_property": "_a_grid_in_property",
"_grid_out_property": "_a_grid_out_property",
"ClientEncryption": "AsyncClientEncryption",
"MongoCryptCallback": "AsyncMongoCryptCallback",
"ExplicitEncrypter": "AsyncExplicitEncrypter",
"AutoEncrypter": "AsyncAutoEncrypter",
"ContextManager": "AsyncContextManager",
"ClientContext": "AsyncClientContext",
"TestCollection": "AsyncTestCollection",
"IntegrationTest": "AsyncIntegrationTest",
"PyMongoTestCase": "AsyncPyMongoTestCase",
"MockClientTest": "AsyncMockClientTest",
"client_context": "async_client_context",
"setUp": "asyncSetUp",
"tearDown": "asyncTearDown",
"wait_until": "await async_wait_until",
"addCleanup": "addAsyncCleanup",
"TestCase": "IsolatedAsyncioTestCase",
"UnitTest": "AsyncUnitTest",
"MockClient": "AsyncMockClient",
"SpecRunner": "AsyncSpecRunner",
"TransactionsBase": "AsyncTransactionsBase",
"get_pool": "await async_get_pool",
"is_mongos": "await async_is_mongos",
"rs_or_single_client": "await async_rs_or_single_client",
"rs_or_single_client_noauth": "await async_rs_or_single_client_noauth",
"rs_client": "await async_rs_client",
"single_client": "await async_single_client",
"from_client": "await async_from_client",
"closing": "aclosing",
"assertRaisesExactly": "asyncAssertRaisesExactly",
"get_mock_client": "await get_async_mock_client",
"close": "await aclose",
}
async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor]
def get_async_methods() -> set[str]:
result: set[str] = set()
for x in async_classes:
methods = {
k
for k, v in vars(x).items()
if callable(v)
and not isinstance(v, classmethod)
and asyncio.iscoroutinefunction(v)
and v.__name__[0] != "_"
}
result = result | methods
return result
async_methods = get_async_methods()
def apply_replacements(lines: list[str]) -> list[str]:
for i in range(len(lines)):
if "_IS_SYNC = True" in lines[i]:
lines[i] = "_IS_SYNC = False"
if "def test" in lines[i]:
lines[i] = lines[i].replace("def test", "async def test")
for k in replacements:
if k in lines[i]:
lines[i] = lines[i].replace(k, replacements[k])
for k in async_methods:
if k + "(" in lines[i]:
tokens = lines[i].split(" ")
for j in range(len(tokens)):
if k + "(" in tokens[j]:
if j < 2:
tokens.insert(0, "await")
else:
tokens.insert(j, "await")
break
new_line = " ".join(tokens)
lines[i] = new_line
return lines
def process_file(input_file: str, output_file: str) -> None:
with open(input_file, "r+") as f:
lines = f.readlines()
lines = apply_replacements(lines)
with open(output_file, "w+") as f2:
f2.seek(0)
f2.writelines(lines)
f2.truncate()
def main() -> None:
args = sys.argv[1:]
sync_file = "./test/" + args[0]
async_file = "./" + args[0]
process_file(sync_file, async_file)
main()