142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
from __future__ import annotations

import asyncio
import inspect
import re
import sys

from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
|
|
|
|
# Identifier rewrites applied to each line of a synchronous test file.
# Keys are names as they appear in the synchronous source; values are what
# they become in the generated asynchronous file. Values that begin with
# "await " also put an ``await`` keyword in front of the renamed name.
replacements: dict[str, str] = {
    # Driver classes.
    "Collection": "AsyncCollection",
    "Database": "AsyncDatabase",
    "Cursor": "AsyncCursor",
    "MongoClient": "AsyncMongoClient",
    "CommandCursor": "AsyncCommandCursor",
    "RawBatchCursor": "AsyncRawBatchCursor",
    "RawBatchCommandCursor": "AsyncRawBatchCommandCursor",
    "ClientSession": "AsyncClientSession",
    "ChangeStream": "AsyncChangeStream",
    "CollectionChangeStream": "AsyncCollectionChangeStream",
    "DatabaseChangeStream": "AsyncDatabaseChangeStream",
    "ClusterChangeStream": "AsyncClusterChangeStream",
    "_Bulk": "_AsyncBulk",
    "_ClientBulk": "_AsyncClientBulk",
    "Connection": "AsyncConnection",
    # Package / prose spellings.
    "synchronous": "asynchronous",
    "Synchronous": "Asynchronous",
    # Iteration builtin.
    "next": "await anext",
    # Lock / condition helpers.
    "_Lock": "_ALock",
    "_Condition": "_ACondition",
    # GridFS names.
    "GridFS": "AsyncGridFS",
    "GridFSBucket": "AsyncGridFSBucket",
    "GridIn": "AsyncGridIn",
    "GridOut": "AsyncGridOut",
    "GridOutCursor": "AsyncGridOutCursor",
    "GridOutIterator": "AsyncGridOutIterator",
    "GridOutChunkIterator": "_AsyncGridOutChunkIterator",
    "_grid_in_property": "_a_grid_in_property",
    "_grid_out_property": "_a_grid_out_property",
    # Encryption names.
    "ClientEncryption": "AsyncClientEncryption",
    "MongoCryptCallback": "AsyncMongoCryptCallback",
    "ExplicitEncrypter": "AsyncExplicitEncrypter",
    "AutoEncrypter": "AsyncAutoEncrypter",
    # Test-suite infrastructure, helpers and unittest hooks.
    "ContextManager": "AsyncContextManager",
    "ClientContext": "AsyncClientContext",
    "TestCollection": "AsyncTestCollection",
    "IntegrationTest": "AsyncIntegrationTest",
    "PyMongoTestCase": "AsyncPyMongoTestCase",
    "MockClientTest": "AsyncMockClientTest",
    "client_context": "async_client_context",
    "setUp": "asyncSetUp",
    "tearDown": "asyncTearDown",
    "wait_until": "await async_wait_until",
    "addCleanup": "addAsyncCleanup",
    "TestCase": "IsolatedAsyncioTestCase",
    "UnitTest": "AsyncUnitTest",
    "MockClient": "AsyncMockClient",
    "SpecRunner": "AsyncSpecRunner",
    "TransactionsBase": "AsyncTransactionsBase",
    "get_pool": "await async_get_pool",
    "is_mongos": "await async_is_mongos",
    "rs_or_single_client": "await async_rs_or_single_client",
    "rs_or_single_client_noauth": "await async_rs_or_single_client_noauth",
    "rs_client": "await async_rs_client",
    "single_client": "await async_single_client",
    "from_client": "await async_from_client",
    "closing": "aclosing",
    "assertRaisesExactly": "asyncAssertRaisesExactly",
    "get_mock_client": "await get_async_mock_client",
    "close": "await aclose",
}
|
|
|
|
# Classes scanned by get_async_methods(); apply_replacements() inserts an
# ``await`` before calls to their public coroutine methods.
async_classes = [
    AsyncMongoClient,
    AsyncDatabase,
    AsyncCollection,
    AsyncCursor,
    AsyncCommandCursor,
]
|
|
|
|
|
|
def get_async_methods() -> set[str]:
|
|
result: set[str] = set()
|
|
for x in async_classes:
|
|
methods = {
|
|
k
|
|
for k, v in vars(x).items()
|
|
if callable(v)
|
|
and not isinstance(v, classmethod)
|
|
and asyncio.iscoroutinefunction(v)
|
|
and v.__name__[0] != "_"
|
|
}
|
|
result = result | methods
|
|
return result
|
|
|
|
|
|
# Public coroutine-method names of the classes in ``async_classes``,
# collected once when the module is loaded.
async_methods: set[str] = get_async_methods()
|
|
|
|
|
|
def apply_replacements(lines: list[str]) -> list[str]:
|
|
for i in range(len(lines)):
|
|
if "_IS_SYNC = True" in lines[i]:
|
|
lines[i] = "_IS_SYNC = False"
|
|
if "def test" in lines[i]:
|
|
lines[i] = lines[i].replace("def test", "async def test")
|
|
for k in replacements:
|
|
if k in lines[i]:
|
|
lines[i] = lines[i].replace(k, replacements[k])
|
|
for k in async_methods:
|
|
if k + "(" in lines[i]:
|
|
tokens = lines[i].split(" ")
|
|
for j in range(len(tokens)):
|
|
if k + "(" in tokens[j]:
|
|
if j < 2:
|
|
tokens.insert(0, "await")
|
|
else:
|
|
tokens.insert(j, "await")
|
|
break
|
|
new_line = " ".join(tokens)
|
|
|
|
lines[i] = new_line
|
|
|
|
return lines
|
|
|
|
|
|
def process_file(input_file: str, output_file: str) -> None:
    """Translate one synchronous test file into its asynchronous counterpart.

    Reads *input_file*, passes its lines through ``apply_replacements`` and
    writes the result to *output_file*, replacing any existing content.
    Both files are read and written as UTF-8, the default encoding of Python
    source, rather than the platform's locale encoding.
    """
    # Read-only: the source is never modified, so don't demand write access
    # as the previous "r+" mode did.
    with open(input_file, encoding="utf-8") as src:
        lines = src.readlines()

    lines = apply_replacements(lines)

    # Mode "w" already truncates, so the old seek(0)/truncate() were no-ops.
    with open(output_file, "w", encoding="utf-8") as dst:
        dst.writelines(lines)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: convert each test file named in ``sys.argv``.

    For every argument NAME, ``./test/NAME`` is read and the converted file is
    written to ``./NAME`` (both relative to the current working directory).
    Exits with a usage message when no file name is given.
    """
    names = sys.argv[1:]
    if not names:
        # Previously this raised a bare IndexError.
        sys.exit("usage: expected at least one test file name relative to ./test/")
    for name in names:
        process_file("./test/" + name, "./" + name)
|
|
|
|
|
|
# Run only when executed as a script, so importing this module (for example to
# reuse apply_replacements) has no side effects.
if __name__ == "__main__":
    main()
|