Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
e296cf9f1b
@ -227,6 +227,186 @@ tasks:
|
||||
- noauth
|
||||
- nossl
|
||||
- sync_async
|
||||
- name: test-4.2-standalone-auth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- auth
|
||||
- ssl
|
||||
- sync
|
||||
- name: test-4.2-standalone-auth-ssl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- auth
|
||||
- ssl
|
||||
- async
|
||||
- name: test-4.2-standalone-auth-ssl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- auth
|
||||
- ssl
|
||||
- sync_async
|
||||
- name: test-4.2-standalone-noauth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- noauth
|
||||
- ssl
|
||||
- sync
|
||||
- name: test-4.2-standalone-noauth-ssl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- noauth
|
||||
- ssl
|
||||
- async
|
||||
- name: test-4.2-standalone-noauth-ssl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- noauth
|
||||
- ssl
|
||||
- sync_async
|
||||
- name: test-4.2-standalone-noauth-nossl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- noauth
|
||||
- nossl
|
||||
- sync
|
||||
- name: test-4.2-standalone-noauth-nossl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- noauth
|
||||
- nossl
|
||||
- async
|
||||
- name: test-4.2-standalone-noauth-nossl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: server
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- standalone
|
||||
- noauth
|
||||
- nossl
|
||||
- sync_async
|
||||
- name: test-4.4-standalone-auth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
@ -1667,6 +1847,186 @@ tasks:
|
||||
- noauth
|
||||
- nossl
|
||||
- sync_async
|
||||
- name: test-4.2-replica_set-auth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- auth
|
||||
- ssl
|
||||
- sync
|
||||
- name: test-4.2-replica_set-auth-ssl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- auth
|
||||
- ssl
|
||||
- async
|
||||
- name: test-4.2-replica_set-auth-ssl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- auth
|
||||
- ssl
|
||||
- sync_async
|
||||
- name: test-4.2-replica_set-noauth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- noauth
|
||||
- ssl
|
||||
- sync
|
||||
- name: test-4.2-replica_set-noauth-ssl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- noauth
|
||||
- ssl
|
||||
- async
|
||||
- name: test-4.2-replica_set-noauth-ssl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- noauth
|
||||
- ssl
|
||||
- sync_async
|
||||
- name: test-4.2-replica_set-noauth-nossl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- noauth
|
||||
- nossl
|
||||
- sync
|
||||
- name: test-4.2-replica_set-noauth-nossl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- noauth
|
||||
- nossl
|
||||
- async
|
||||
- name: test-4.2-replica_set-noauth-nossl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: replica_set
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- replica_set
|
||||
- noauth
|
||||
- nossl
|
||||
- sync_async
|
||||
- name: test-4.4-replica_set-auth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
@ -3107,6 +3467,186 @@ tasks:
|
||||
- noauth
|
||||
- nossl
|
||||
- sync_async
|
||||
- name: test-4.2-sharded_cluster-auth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- auth
|
||||
- ssl
|
||||
- sync
|
||||
- name: test-4.2-sharded_cluster-auth-ssl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- auth
|
||||
- ssl
|
||||
- async
|
||||
- name: test-4.2-sharded_cluster-auth-ssl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: auth
|
||||
SSL: ssl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- auth
|
||||
- ssl
|
||||
- sync_async
|
||||
- name: test-4.2-sharded_cluster-noauth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- noauth
|
||||
- ssl
|
||||
- sync
|
||||
- name: test-4.2-sharded_cluster-noauth-ssl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- noauth
|
||||
- ssl
|
||||
- async
|
||||
- name: test-4.2-sharded_cluster-noauth-ssl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- noauth
|
||||
- ssl
|
||||
- sync_async
|
||||
- name: test-4.2-sharded_cluster-noauth-nossl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: sync
|
||||
TEST_SUITES: default
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- noauth
|
||||
- nossl
|
||||
- sync
|
||||
- name: test-4.2-sharded_cluster-noauth-nossl-async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: async
|
||||
TEST_SUITES: default_async
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- noauth
|
||||
- nossl
|
||||
- async
|
||||
- name: test-4.2-sharded_cluster-noauth-nossl-sync_async
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
vars:
|
||||
VERSION: "4.2"
|
||||
TOPOLOGY: sharded_cluster
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
- func: run tests
|
||||
vars:
|
||||
AUTH: noauth
|
||||
SSL: nossl
|
||||
SYNC: sync_async
|
||||
TEST_SUITES: ""
|
||||
tags:
|
||||
- "4.2"
|
||||
- sharded_cluster
|
||||
- noauth
|
||||
- nossl
|
||||
- sync_async
|
||||
- name: test-4.4-sharded_cluster-auth-ssl-sync
|
||||
commands:
|
||||
- func: bootstrap mongo-orchestration
|
||||
|
||||
@ -817,10 +817,23 @@ buildvariants:
|
||||
PYTHON_BINARY: /opt/python/3.13/bin/python3
|
||||
|
||||
# Ocsp tests
|
||||
- name: ocsp-rhel8-v4.4-python3.9
|
||||
- name: ocsp-rhel8-v4.2-python3.9
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 v4.4 Python3.9
|
||||
display_name: OCSP RHEL8 v4.2 Python3.9
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
expansions:
|
||||
AUTH: noauth
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: "4.2"
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
- name: ocsp-rhel8-v4.4-python3.10
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 v4.4 Python3.10
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -829,11 +842,11 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: "4.4"
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
- name: ocsp-rhel8-v5.0-python3.10
|
||||
PYTHON_BINARY: /opt/python/3.10/bin/python3
|
||||
- name: ocsp-rhel8-v5.0-python3.11
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 v5.0 Python3.10
|
||||
display_name: OCSP RHEL8 v5.0 Python3.11
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -842,11 +855,11 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: "5.0"
|
||||
PYTHON_BINARY: /opt/python/3.10/bin/python3
|
||||
- name: ocsp-rhel8-v6.0-python3.11
|
||||
PYTHON_BINARY: /opt/python/3.11/bin/python3
|
||||
- name: ocsp-rhel8-v6.0-python3.12
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 v6.0 Python3.11
|
||||
display_name: OCSP RHEL8 v6.0 Python3.12
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -855,11 +868,11 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: "6.0"
|
||||
PYTHON_BINARY: /opt/python/3.11/bin/python3
|
||||
- name: ocsp-rhel8-v7.0-python3.12
|
||||
PYTHON_BINARY: /opt/python/3.12/bin/python3
|
||||
- name: ocsp-rhel8-v7.0-python3.13
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 v7.0 Python3.12
|
||||
display_name: OCSP RHEL8 v7.0 Python3.13
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -868,11 +881,11 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: "7.0"
|
||||
PYTHON_BINARY: /opt/python/3.12/bin/python3
|
||||
- name: ocsp-rhel8-v8.0-python3.13
|
||||
PYTHON_BINARY: /opt/python/3.13/bin/python3
|
||||
- name: ocsp-rhel8-v8.0-pypy3.10
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 v8.0 Python3.13
|
||||
display_name: OCSP RHEL8 v8.0 PyPy3.10
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -881,11 +894,11 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: "8.0"
|
||||
PYTHON_BINARY: /opt/python/3.13/bin/python3
|
||||
- name: ocsp-rhel8-rapid-pypy3.10
|
||||
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
|
||||
- name: ocsp-rhel8-rapid-python3.9
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 rapid PyPy3.10
|
||||
display_name: OCSP RHEL8 rapid Python3.9
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -894,11 +907,11 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: rapid
|
||||
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
|
||||
- name: ocsp-rhel8-latest-python3.9
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
- name: ocsp-rhel8-latest-python3.10
|
||||
tasks:
|
||||
- name: .ocsp
|
||||
display_name: OCSP RHEL8 latest Python3.9
|
||||
display_name: OCSP RHEL8 latest Python3.10
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 20160
|
||||
@ -907,7 +920,7 @@ buildvariants:
|
||||
SSL: ssl
|
||||
TOPOLOGY: server
|
||||
VERSION: latest
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
PYTHON_BINARY: /opt/python/3.10/bin/python3
|
||||
- name: ocsp-win64-v4.4-python3.9
|
||||
tasks:
|
||||
- name: .ocsp-rsa !.ocsp-staple
|
||||
@ -1066,6 +1079,19 @@ buildvariants:
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
|
||||
# Server tests
|
||||
- name: test-rhel8-python3.9-cov-no-c
|
||||
tasks:
|
||||
- name: .standalone .sync_async
|
||||
- name: .replica_set .sync_async
|
||||
- name: .sharded_cluster .sync_async
|
||||
display_name: "* Test RHEL8 Python3.9 cov No C"
|
||||
run_on:
|
||||
- rhel87-small
|
||||
expansions:
|
||||
COVERAGE: coverage
|
||||
NO_EXT: "1"
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
tags: [coverage_tag]
|
||||
- name: test-rhel8-python3.9-cov
|
||||
tasks:
|
||||
- name: .standalone .sync_async
|
||||
@ -1078,6 +1104,19 @@ buildvariants:
|
||||
COVERAGE: coverage
|
||||
PYTHON_BINARY: /opt/python/3.9/bin/python3
|
||||
tags: [coverage_tag]
|
||||
- name: test-rhel8-python3.13-cov-no-c
|
||||
tasks:
|
||||
- name: .standalone .sync_async
|
||||
- name: .replica_set .sync_async
|
||||
- name: .sharded_cluster .sync_async
|
||||
display_name: "* Test RHEL8 Python3.13 cov No C"
|
||||
run_on:
|
||||
- rhel87-small
|
||||
expansions:
|
||||
COVERAGE: coverage
|
||||
NO_EXT: "1"
|
||||
PYTHON_BINARY: /opt/python/3.13/bin/python3
|
||||
tags: [coverage_tag]
|
||||
- name: test-rhel8-python3.13-cov
|
||||
tasks:
|
||||
- name: .standalone .sync_async
|
||||
@ -1090,6 +1129,19 @@ buildvariants:
|
||||
COVERAGE: coverage
|
||||
PYTHON_BINARY: /opt/python/3.13/bin/python3
|
||||
tags: [coverage_tag]
|
||||
- name: test-rhel8-pypy3.10-cov-no-c
|
||||
tasks:
|
||||
- name: .standalone .sync_async
|
||||
- name: .replica_set .sync_async
|
||||
- name: .sharded_cluster .sync_async
|
||||
display_name: "* Test RHEL8 PyPy3.10 cov No C"
|
||||
run_on:
|
||||
- rhel87-small
|
||||
expansions:
|
||||
COVERAGE: coverage
|
||||
NO_EXT: "1"
|
||||
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
|
||||
tags: [coverage_tag]
|
||||
- name: test-rhel8-pypy3.10-cov
|
||||
tasks:
|
||||
- name: .standalone .sync_async
|
||||
@ -1338,6 +1390,7 @@ buildvariants:
|
||||
- name: storage-inmemory-rhel8-python3.9
|
||||
tasks:
|
||||
- name: .standalone .noauth .nossl .4.0 .sync_async
|
||||
- name: .standalone .noauth .nossl .4.2 .sync_async
|
||||
- name: .standalone .noauth .nossl .4.4 .sync_async
|
||||
- name: .standalone .noauth .nossl .5.0 .sync_async
|
||||
- name: .standalone .noauth .nossl .6.0 .sync_async
|
||||
|
||||
@ -26,7 +26,7 @@ from shrub.v3.shrub_service import ShrubService
|
||||
# Globals
|
||||
##############
|
||||
|
||||
ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"]
|
||||
ALL_VERSIONS = ["4.0", "4.2", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"]
|
||||
CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
PYPYS = ["pypy3.10"]
|
||||
ALL_PYTHONS = CPYTHONS + PYPYS
|
||||
@ -279,8 +279,9 @@ def create_server_variants() -> list[BuildVariant]:
|
||||
host = DEFAULT_HOST
|
||||
# Prefix the display name with an asterisk so it is sorted first.
|
||||
base_display_name = "* Test"
|
||||
for python in [*MIN_MAX_PYTHON, PYPYS[-1]]:
|
||||
for python, c_ext in product([*MIN_MAX_PYTHON, PYPYS[-1]], C_EXTS):
|
||||
expansions = dict(COVERAGE="coverage")
|
||||
handle_c_ext(c_ext, expansions)
|
||||
display_name = get_display_name(base_display_name, host, python=python, **expansions)
|
||||
variant = create_variant(
|
||||
[f".{t} .sync_async" for t in TOPOLOGIES],
|
||||
|
||||
@ -178,7 +178,7 @@ documentation including narrative docs, and the [Sphinx docstring format](https:
|
||||
You can build the documentation locally by running:
|
||||
|
||||
```bash
|
||||
just docs-build
|
||||
just docs
|
||||
```
|
||||
|
||||
When updating docs, it can be helpful to run the live docs server as:
|
||||
@ -261,6 +261,11 @@ To prevent the `synchro` hook from accidentally overwriting code, it first check
|
||||
of a file is changing and not its async counterpart, and will fail.
|
||||
In the unlikely scenario that you want to override this behavior, first export `OVERRIDE_SYNCHRO_CHECK=1`.
|
||||
|
||||
Sometimes, the `synchro` hook will fail and introduce changes many previously unmodified files. This is due to static
|
||||
Python errors, such as missing imports, incorrect syntax, or other fatal typos. To resolve these issues,
|
||||
run `pre-commit run --all-files --hook-stage manual ruff` and fix all reported errors before running the `synchro`
|
||||
hook again.
|
||||
|
||||
## Converting a test to async
|
||||
The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a
|
||||
partially-converted asynchronous version of the same name to the `test/asynchronous` directory.
|
||||
|
||||
@ -420,3 +420,10 @@ the collection:
|
||||
DuplicateKeyError: E11000 duplicate key error index: test_database.profiles.$user_id_1 dup key: { : 212 }
|
||||
|
||||
.. seealso:: The MongoDB documentation on `indexes <https://www.mongodb.com/docs/manual/indexes/>`_
|
||||
|
||||
Task Cancellation
|
||||
-----------------
|
||||
`Cancelling <https://docs.python.org/3/library/asyncio-task.html#task-cancellation>`_ an asyncio Task
|
||||
that is running a PyMongo operation is treated as a fatal interrupt. Any connections, cursors, and transactions
|
||||
involved in a cancelled Task will be safely closed and cleaned up as part of the cancellation. If those resources are
|
||||
also used elsewhere, attempting to utilize them after the cancellation will result in an error.
|
||||
|
||||
@ -231,7 +231,7 @@ class AsyncGridFS:
|
||||
try:
|
||||
doc = await anext(cursor)
|
||||
return AsyncGridOut(self._collection, file_document=doc, session=session)
|
||||
except StopIteration:
|
||||
except StopAsyncIteration:
|
||||
raise NoFile("no version %d for filename %r" % (version, filename)) from None
|
||||
|
||||
async def get_last_version(
|
||||
|
||||
@ -391,7 +391,8 @@ class AsyncChangeStream(Generic[_DocumentType]):
|
||||
if not _resumable(exc) and not exc.timeout:
|
||||
await self.close()
|
||||
raise
|
||||
except Exception:
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
await self.close()
|
||||
raise
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ Causally Consistent Reads
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with client.start_session(causal_consistency=True) as session:
|
||||
async with client.start_session(causal_consistency=True) as session:
|
||||
collection = client.db.collection
|
||||
await collection.update_one({"_id": 1}, {"$set": {"x": 10}}, session=session)
|
||||
secondary_c = collection.with_options(read_preference=ReadPreference.SECONDARY)
|
||||
@ -53,8 +53,8 @@ operation:
|
||||
|
||||
orders = client.db.orders
|
||||
inventory = client.db.inventory
|
||||
with client.start_session() as session:
|
||||
async with session.start_transaction():
|
||||
async with client.start_session() as session:
|
||||
async with await session.start_transaction():
|
||||
await orders.insert_one({"sku": "abc123", "qty": 100}, session=session)
|
||||
await inventory.update_one(
|
||||
{"sku": "abc123", "qty": {"$gte": 100}},
|
||||
@ -62,7 +62,7 @@ operation:
|
||||
session=session,
|
||||
)
|
||||
|
||||
Upon normal completion of ``async with session.start_transaction()`` block, the
|
||||
Upon normal completion of ``async with await session.start_transaction()`` block, the
|
||||
transaction automatically calls :meth:`AsyncClientSession.commit_transaction`.
|
||||
If the block exits with an exception, the transaction automatically calls
|
||||
:meth:`AsyncClientSession.abort_transaction`.
|
||||
@ -113,7 +113,7 @@ replica set secondaries.
|
||||
.. code-block:: python
|
||||
|
||||
# Each read using this session reads data from the same point in time.
|
||||
with client.start_session(snapshot=True) as session:
|
||||
async with client.start_session(snapshot=True) as session:
|
||||
order = await orders.find_one({"sku": "abc123"}, session=session)
|
||||
inventory = await inventory.find_one({"sku": "abc123"}, session=session)
|
||||
|
||||
@ -619,7 +619,7 @@ class AsyncClientSession:
|
||||
await inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}},
|
||||
{"$inc": {"qty": -100}}, session=session)
|
||||
|
||||
with client.start_session() as session:
|
||||
async with client.start_session() as session:
|
||||
await session.with_transaction(callback)
|
||||
|
||||
To pass arbitrary arguments to the ``callback``, wrap your callable
|
||||
@ -628,7 +628,7 @@ class AsyncClientSession:
|
||||
async def callback(session, custom_arg, custom_kwarg=None):
|
||||
# Transaction operations...
|
||||
|
||||
with client.start_session() as session:
|
||||
async with client.start_session() as session:
|
||||
await session.with_transaction(
|
||||
lambda s: callback(s, "custom_arg", custom_kwarg=1))
|
||||
|
||||
@ -697,7 +697,8 @@ class AsyncClientSession:
|
||||
)
|
||||
try:
|
||||
ret = await callback(self)
|
||||
except Exception as exc:
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as exc:
|
||||
if self.in_transaction:
|
||||
await self.abort_transaction()
|
||||
if (
|
||||
|
||||
@ -1126,7 +1126,8 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
self._killed = True
|
||||
await self.close()
|
||||
raise
|
||||
except Exception:
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
await self.close()
|
||||
raise
|
||||
|
||||
|
||||
@ -127,8 +127,6 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
# BSON encoding/decoding errors are unrelated to encryption so
|
||||
# we should propagate them unchanged.
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptionError(exc) from exc
|
||||
|
||||
@ -766,8 +764,6 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
await database.create_collection(name=name, **kwargs),
|
||||
encrypted_fields,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||
|
||||
|
||||
@ -276,7 +276,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
:param type_registry: instance of
|
||||
:class:`~bson.codec_options.TypeRegistry` to enable encoding
|
||||
and decoding of custom types.
|
||||
:param datetime_conversion: Specifies how UTC datetimes should be decoded
|
||||
:param kwargs: **Additional optional parameters available as keyword arguments:**
|
||||
|
||||
- `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded
|
||||
within BSON. Valid options include 'datetime_ms' to return as a
|
||||
DatetimeMS, 'datetime' to return as a datetime.datetime and
|
||||
raising a ValueError for out-of-range values, 'datetime_auto' to
|
||||
@ -284,9 +286,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
out-of-range and 'datetime_clamp' to clamp to the minimum and
|
||||
maximum possible datetimes. Defaults to 'datetime'. See
|
||||
:ref:`handling-out-of-range-datetimes` for details.
|
||||
|
||||
| **Other optional parameters can be passed as keyword arguments:**
|
||||
|
||||
- `directConnection` (optional): if ``True``, forces this client to
|
||||
connect directly to the specified MongoDB host as a standalone.
|
||||
If ``false``, the client connects to the entire replica set of
|
||||
@ -2044,8 +2043,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -2060,8 +2057,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_ids in address_to_cursor_ids.items():
|
||||
try:
|
||||
await self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
@ -2076,8 +2071,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
try:
|
||||
await self._process_kill_cursors()
|
||||
await self._topology.update_pool()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -262,8 +262,6 @@ class Monitor(MonitorBase):
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
await self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
except Exception as error:
|
||||
@ -429,8 +427,6 @@ class SrvMonitor(MonitorBase):
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
@ -494,8 +490,6 @@ class _RttMonitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
await self._pool.reset()
|
||||
|
||||
|
||||
@ -559,7 +559,7 @@ class AsyncConnection:
|
||||
)
|
||||
except (OperationFailure, NotPrimaryError):
|
||||
raise
|
||||
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
|
||||
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -576,6 +576,7 @@ class AsyncConnection:
|
||||
|
||||
try:
|
||||
await async_sendall(self.conn, message)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -586,6 +587,7 @@ class AsyncConnection:
|
||||
"""
|
||||
try:
|
||||
return await receive_message(self, request_id, self.max_message_size)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -704,8 +706,6 @@ class AsyncConnection:
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@ -1269,6 +1269,7 @@ class Pool:
|
||||
|
||||
try:
|
||||
sock = await _configured_socket(self.address, self.opts)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
async with self.lock:
|
||||
self.active_contexts.discard(tmp_context)
|
||||
@ -1308,6 +1309,7 @@ class Pool:
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
await conn.authenticate()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
async with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
@ -1369,6 +1371,7 @@ class Pool:
|
||||
async with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
yield conn
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
# Exception in caller. Ensure the connection gets returned.
|
||||
# Note that when pinned is True, the session owns the
|
||||
@ -1515,6 +1518,7 @@ class Pool:
|
||||
async with self._max_connecting_cond:
|
||||
self._pending -= 1
|
||||
self._max_connecting_cond.notify()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
if conn:
|
||||
# We checked out a socket but authentication failed.
|
||||
|
||||
@ -100,6 +100,7 @@ class AsyncPeriodicExecutor:
|
||||
if not await self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
self._stopped = True
|
||||
raise
|
||||
@ -232,6 +233,7 @@ class PeriodicExecutor:
|
||||
if not self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
# Catch KeyboardInterrupt, etc. and cleanup.
|
||||
except BaseException:
|
||||
with self._lock:
|
||||
self._stopped = True
|
||||
|
||||
@ -389,7 +389,8 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
if not _resumable(exc) and not exc.timeout:
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
|
||||
@ -694,7 +694,8 @@ class ClientSession:
|
||||
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
|
||||
try:
|
||||
ret = callback(self)
|
||||
except Exception as exc:
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as exc:
|
||||
if self.in_transaction:
|
||||
self.abort_transaction()
|
||||
if (
|
||||
|
||||
@ -1124,7 +1124,8 @@ class Cursor(Generic[_DocumentType]):
|
||||
self._killed = True
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import socket
|
||||
@ -127,8 +126,6 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
# BSON encoding/decoding errors are unrelated to encryption so
|
||||
# we should propagate them unchanged.
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptionError(exc) from exc
|
||||
|
||||
@ -760,8 +757,6 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
database.create_collection(name=name, **kwargs),
|
||||
encrypted_fields,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||
|
||||
|
||||
@ -274,7 +274,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
:param type_registry: instance of
|
||||
:class:`~bson.codec_options.TypeRegistry` to enable encoding
|
||||
and decoding of custom types.
|
||||
:param datetime_conversion: Specifies how UTC datetimes should be decoded
|
||||
:param kwargs: **Additional optional parameters available as keyword arguments:**
|
||||
|
||||
- `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded
|
||||
within BSON. Valid options include 'datetime_ms' to return as a
|
||||
DatetimeMS, 'datetime' to return as a datetime.datetime and
|
||||
raising a ValueError for out-of-range values, 'datetime_auto' to
|
||||
@ -282,9 +284,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
out-of-range and 'datetime_clamp' to clamp to the minimum and
|
||||
maximum possible datetimes. Defaults to 'datetime'. See
|
||||
:ref:`handling-out-of-range-datetimes` for details.
|
||||
|
||||
| **Other optional parameters can be passed as keyword arguments:**
|
||||
|
||||
- `directConnection` (optional): if ``True``, forces this client to
|
||||
connect directly to the specified MongoDB host as a standalone.
|
||||
If ``false``, the client connects to the entire replica set of
|
||||
@ -2038,8 +2037,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -2054,8 +2051,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_ids in address_to_cursor_ids.items():
|
||||
try:
|
||||
self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
@ -2070,8 +2065,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
try:
|
||||
self._process_kill_cursors()
|
||||
self._topology.update_pool()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -260,8 +260,6 @@ class Monitor(MonitorBase):
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
except Exception as error:
|
||||
@ -427,8 +425,6 @@ class SrvMonitor(MonitorBase):
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
@ -492,8 +488,6 @@ class _RttMonitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
self.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self._pool.reset()
|
||||
|
||||
|
||||
@ -559,7 +559,7 @@ class Connection:
|
||||
)
|
||||
except (OperationFailure, NotPrimaryError):
|
||||
raise
|
||||
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
|
||||
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -576,6 +576,7 @@ class Connection:
|
||||
|
||||
try:
|
||||
sendall(self.conn, message)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -586,6 +587,7 @@ class Connection:
|
||||
"""
|
||||
try:
|
||||
return receive_message(self, request_id, self.max_message_size)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
@ -702,8 +704,6 @@ class Connection:
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@ -1263,6 +1263,7 @@ class Pool:
|
||||
|
||||
try:
|
||||
sock = _configured_socket(self.address, self.opts)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
with self.lock:
|
||||
self.active_contexts.discard(tmp_context)
|
||||
@ -1302,6 +1303,7 @@ class Pool:
|
||||
handler.contribute_socket(conn, completed_handshake=False)
|
||||
|
||||
conn.authenticate()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
@ -1363,6 +1365,7 @@ class Pool:
|
||||
with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
yield conn
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
# Exception in caller. Ensure the connection gets returned.
|
||||
# Note that when pinned is True, the session owns the
|
||||
@ -1509,6 +1512,7 @@ class Pool:
|
||||
with self._max_connecting_cond:
|
||||
self._pending -= 1
|
||||
self._max_connecting_cond.notify()
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException:
|
||||
if conn:
|
||||
# We checked out a socket but authentication failed.
|
||||
|
||||
@ -66,7 +66,7 @@ class DummyMonitor:
|
||||
def cancel_check(self):
|
||||
pass
|
||||
|
||||
def join(self):
|
||||
async def join(self):
|
||||
pass
|
||||
|
||||
def open(self):
|
||||
@ -75,7 +75,7 @@ class DummyMonitor:
|
||||
def request_check(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
self.opened = False
|
||||
|
||||
|
||||
|
||||
126
test/asynchronous/test_async_cancellation.py
Normal file
126
test/asynchronous/test_async_cancellation.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test that async cancellation performed by users clean up resources correctly."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from test.utils import async_get_pool, delay, one
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected
|
||||
|
||||
|
||||
class TestAsyncCancellation(AsyncIntegrationTest):
|
||||
async def test_async_cancellation_closes_connection(self):
|
||||
pool = await async_get_pool(self.client)
|
||||
await self.client.db.test.insert_one({"x": 1})
|
||||
self.addAsyncCleanup(self.client.db.test.delete_many, {})
|
||||
|
||||
conn = one(pool.conns)
|
||||
|
||||
async def task():
|
||||
await self.client.db.test.find_one({"$where": delay(0.2)})
|
||||
|
||||
task = asyncio.create_task(task())
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_async_cancellation_aborts_transaction(self):
|
||||
await self.client.db.test.insert_one({"x": 1})
|
||||
self.addAsyncCleanup(self.client.db.test.delete_many, {})
|
||||
|
||||
session = self.client.start_session()
|
||||
|
||||
async def callback(session):
|
||||
await self.client.db.test.find_one({"$where": delay(0.2)}, session=session)
|
||||
|
||||
async def task():
|
||||
await session.with_transaction(callback)
|
||||
|
||||
task = asyncio.create_task(task())
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
self.assertFalse(session.in_transaction)
|
||||
|
||||
@async_client_context.require_failCommand_blockConnection
|
||||
async def test_async_cancellation_closes_cursor(self):
|
||||
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
|
||||
self.addAsyncCleanup(self.client.db.test.delete_many, {})
|
||||
|
||||
cursor = self.client.db.test.find({}, batch_size=1)
|
||||
await cursor.next()
|
||||
|
||||
# Make sure getMore commands block
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": "alwaysOn",
|
||||
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
|
||||
}
|
||||
|
||||
async def task():
|
||||
async with self.fail_point(fail_command):
|
||||
await cursor.next()
|
||||
|
||||
task = asyncio.create_task(task())
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
@async_client_context.require_change_streams
|
||||
@async_client_context.require_failCommand_blockConnection
|
||||
async def test_async_cancellation_closes_change_stream(self):
|
||||
self.addAsyncCleanup(self.client.db.test.delete_many, {})
|
||||
change_stream = await self.client.db.test.watch(batch_size=2)
|
||||
|
||||
# Make sure getMore commands block
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": "alwaysOn",
|
||||
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
|
||||
}
|
||||
|
||||
async def task():
|
||||
async with self.fail_point(fail_command):
|
||||
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
|
||||
await change_stream.next()
|
||||
|
||||
task = asyncio.create_task(task())
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
self.assertTrue(change_stream._closed)
|
||||
@ -961,7 +961,6 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
async def test_write_concern_failure_ordered(self):
|
||||
self.skipTest("Skipping until PYTHON-4865 is resolved.")
|
||||
details = None
|
||||
|
||||
# Ensure we don't raise on wnote.
|
||||
|
||||
479
test/asynchronous/test_connection_monitoring.py
Normal file
479
test/asynchronous/test_connection_monitoring.py
Normal file
@ -0,0 +1,479 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Execute Transactions Spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest
|
||||
from test.asynchronous.pymongo_mocks import DummyMonitor
|
||||
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
async_client_context,
|
||||
async_get_pool,
|
||||
async_get_pools,
|
||||
async_wait_until,
|
||||
camel_to_snake,
|
||||
)
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.pool import PoolState, _PoolClosedError
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
WaitQueueTimeoutError,
|
||||
)
|
||||
from pymongo.monitoring import (
|
||||
ConnectionCheckedInEvent,
|
||||
ConnectionCheckedOutEvent,
|
||||
ConnectionCheckOutFailedEvent,
|
||||
ConnectionCheckOutFailedReason,
|
||||
ConnectionCheckOutStartedEvent,
|
||||
ConnectionClosedEvent,
|
||||
ConnectionClosedReason,
|
||||
ConnectionCreatedEvent,
|
||||
ConnectionReadyEvent,
|
||||
PoolClearedEvent,
|
||||
PoolClosedEvent,
|
||||
PoolCreatedEvent,
|
||||
PoolReadyEvent,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.topology_description import updated_topology_description
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
OBJECT_TYPES = {
|
||||
# Event types.
|
||||
"ConnectionCheckedIn": ConnectionCheckedInEvent,
|
||||
"ConnectionCheckedOut": ConnectionCheckedOutEvent,
|
||||
"ConnectionCheckOutFailed": ConnectionCheckOutFailedEvent,
|
||||
"ConnectionClosed": ConnectionClosedEvent,
|
||||
"ConnectionCreated": ConnectionCreatedEvent,
|
||||
"ConnectionReady": ConnectionReadyEvent,
|
||||
"ConnectionCheckOutStarted": ConnectionCheckOutStartedEvent,
|
||||
"ConnectionPoolCreated": PoolCreatedEvent,
|
||||
"ConnectionPoolReady": PoolReadyEvent,
|
||||
"ConnectionPoolCleared": PoolClearedEvent,
|
||||
"ConnectionPoolClosed": PoolClosedEvent,
|
||||
# Error types.
|
||||
"PoolClosedError": _PoolClosedError,
|
||||
"WaitQueueTimeoutError": WaitQueueTimeoutError,
|
||||
}
|
||||
|
||||
|
||||
class AsyncTestCMAP(AsyncIntegrationTest):
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring")
|
||||
|
||||
# Test operations:
|
||||
|
||||
async def start(self, op):
|
||||
"""Run the 'start' thread operation."""
|
||||
target = op["target"]
|
||||
thread = SpecRunnerTask(target)
|
||||
await thread.start()
|
||||
self.targets[target] = thread
|
||||
|
||||
async def wait(self, op):
|
||||
"""Run the 'wait' operation."""
|
||||
await asyncio.sleep(op["ms"] / 1000.0)
|
||||
|
||||
async def wait_for_thread(self, op):
|
||||
"""Run the 'waitForThread' operation."""
|
||||
target = op["target"]
|
||||
thread = self.targets[target]
|
||||
await thread.stop()
|
||||
await thread.join()
|
||||
if thread.exc:
|
||||
raise thread.exc
|
||||
self.assertFalse(thread.ops)
|
||||
|
||||
async def wait_for_event(self, op):
|
||||
"""Run the 'waitForEvent' operation."""
|
||||
event = OBJECT_TYPES[op["event"]]
|
||||
count = op["count"]
|
||||
timeout = op.get("timeout", 10000) / 1000.0
|
||||
await async_wait_until(
|
||||
lambda: self.listener.event_count(event) >= count,
|
||||
f"find {count} {event} event(s)",
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def check_out(self, op):
|
||||
"""Run the 'checkOut' operation."""
|
||||
label = op["label"]
|
||||
async with self.pool.checkout() as conn:
|
||||
# Call 'pin_cursor' so we can hold the socket.
|
||||
conn.pin_cursor()
|
||||
if label:
|
||||
self.labels[label] = conn
|
||||
else:
|
||||
self.addAsyncCleanup(conn.close_conn, None)
|
||||
|
||||
async def check_in(self, op):
|
||||
"""Run the 'checkIn' operation."""
|
||||
label = op["connection"]
|
||||
conn = self.labels[label]
|
||||
await self.pool.checkin(conn)
|
||||
|
||||
async def ready(self, op):
|
||||
"""Run the 'ready' operation."""
|
||||
await self.pool.ready()
|
||||
|
||||
async def clear(self, op):
|
||||
"""Run the 'clear' operation."""
|
||||
if "interruptInUseConnections" in op:
|
||||
await self.pool.reset(interrupt_connections=op["interruptInUseConnections"])
|
||||
else:
|
||||
await self.pool.reset()
|
||||
|
||||
async def close(self, op):
|
||||
"""Run the 'close' operation."""
|
||||
await self.pool.close()
|
||||
|
||||
async def run_operation(self, op):
|
||||
"""Run a single operation in a test."""
|
||||
op_name = camel_to_snake(op["name"])
|
||||
thread = op["thread"]
|
||||
meth = getattr(self, op_name)
|
||||
if thread:
|
||||
await self.targets[thread].schedule(lambda: meth(op))
|
||||
else:
|
||||
await meth(op)
|
||||
|
||||
async def run_operations(self, ops):
|
||||
"""Run a test's operations."""
|
||||
for op in ops:
|
||||
self._ops.append(op)
|
||||
await self.run_operation(op)
|
||||
|
||||
def check_object(self, actual, expected):
|
||||
"""Assert that the actual object matches the expected object."""
|
||||
self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]])
|
||||
for attr, expected_val in expected.items():
|
||||
if attr == "type":
|
||||
continue
|
||||
c2s = camel_to_snake(attr)
|
||||
if c2s == "interrupt_in_use_connections":
|
||||
c2s = "interrupt_connections"
|
||||
actual_val = getattr(actual, c2s)
|
||||
if expected_val == 42:
|
||||
self.assertIsNotNone(actual_val)
|
||||
else:
|
||||
self.assertEqual(actual_val, expected_val)
|
||||
|
||||
def check_event(self, actual, expected):
|
||||
"""Assert that the actual event matches the expected event."""
|
||||
self.check_object(actual, expected)
|
||||
|
||||
def actual_events(self, ignore):
|
||||
"""Return all the non-ignored events."""
|
||||
ignore = tuple(OBJECT_TYPES[name] for name in ignore)
|
||||
return [event for event in self.listener.events if not isinstance(event, ignore)]
|
||||
|
||||
def check_events(self, events, ignore):
|
||||
"""Check the events of a test."""
|
||||
actual_events = self.actual_events(ignore)
|
||||
for actual, expected in zip(actual_events, events):
|
||||
self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}")
|
||||
self.check_event(actual, expected)
|
||||
|
||||
if len(events) > len(actual_events):
|
||||
self.fail(f"missing events: {events[len(actual_events) :]!r}")
|
||||
|
||||
def check_error(self, actual, expected):
|
||||
message = expected.pop("message")
|
||||
self.check_object(actual, expected)
|
||||
self.assertIn(message, str(actual))
|
||||
|
||||
async def _set_fail_point(self, client, command_args):
|
||||
cmd = SON([("configureFailPoint", "failCommand")])
|
||||
cmd.update(command_args)
|
||||
await client.admin.command(cmd)
|
||||
|
||||
async def set_fail_point(self, command_args):
|
||||
if not async_client_context.supports_failCommand_fail_point:
|
||||
self.skipTest("failCommand fail point must be supported")
|
||||
await self._set_fail_point(self.client, command_args)
|
||||
|
||||
async def run_scenario(self, scenario_def, test):
|
||||
"""Run a CMAP spec test."""
|
||||
self.logs: list = []
|
||||
self.assertEqual(scenario_def["version"], 1)
|
||||
self.assertIn(scenario_def["style"], ["unit", "integration"])
|
||||
self.listener = CMAPListener()
|
||||
self._ops: list = []
|
||||
|
||||
# Configure the fail point before creating the client.
|
||||
if "failPoint" in test:
|
||||
fp = test["failPoint"]
|
||||
await self.set_fail_point(fp)
|
||||
self.addAsyncCleanup(
|
||||
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
|
||||
)
|
||||
|
||||
opts = test["poolOptions"].copy()
|
||||
opts["event_listeners"] = [self.listener]
|
||||
opts["_monitor_class"] = DummyMonitor
|
||||
opts["connect"] = False
|
||||
# Support backgroundThreadIntervalMS, default to 50ms.
|
||||
interval = opts.pop("backgroundThreadIntervalMS", 50)
|
||||
if interval < 0:
|
||||
kill_cursor_frequency = 99999999
|
||||
else:
|
||||
kill_cursor_frequency = interval / 1000.0
|
||||
with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05):
|
||||
client = await self.async_single_client(**opts)
|
||||
# Update the SD to a known type because the DummyMonitor will not.
|
||||
# Note we cannot simply call topology.on_change because that would
|
||||
# internally call pool.ready() which introduces unexpected
|
||||
# PoolReadyEvents. Instead, update the initial state before
|
||||
# opening the Topology.
|
||||
td = async_client_context.client._topology.description
|
||||
sd = td.server_descriptions()[
|
||||
(await async_client_context.host, await async_client_context.port)
|
||||
]
|
||||
client._topology._description = updated_topology_description(
|
||||
client._topology._description, sd
|
||||
)
|
||||
# When backgroundThreadIntervalMS is negative we do not start the
|
||||
# background thread to ensure it never runs.
|
||||
if interval < 0:
|
||||
await client._topology.open()
|
||||
else:
|
||||
await client._get_topology()
|
||||
self.pool = list(client._topology._servers.values())[0].pool
|
||||
|
||||
# Map of target names to Thread objects.
|
||||
self.targets: dict = {}
|
||||
# Map of label names to AsyncConnection objects
|
||||
self.labels: dict = {}
|
||||
|
||||
async def cleanup():
|
||||
for t in self.targets.values():
|
||||
await t.stop()
|
||||
for t in self.targets.values():
|
||||
await t.join(5)
|
||||
for conn in self.labels.values():
|
||||
conn.close_conn(None)
|
||||
|
||||
self.addAsyncCleanup(cleanup)
|
||||
|
||||
try:
|
||||
if test["error"]:
|
||||
with self.assertRaises(PyMongoError) as ctx:
|
||||
await self.run_operations(test["operations"])
|
||||
self.check_error(ctx.exception, test["error"])
|
||||
else:
|
||||
await self.run_operations(test["operations"])
|
||||
|
||||
self.check_events(test["events"], test["ignore"])
|
||||
except Exception:
|
||||
# Print the events after a test failure.
|
||||
print("\nFailed test: {!r}".format(test["description"]))
|
||||
print("Operations:")
|
||||
for op in self._ops:
|
||||
print(op)
|
||||
print("Threads:")
|
||||
print(self.targets)
|
||||
print("AsyncConnections:")
|
||||
print(self.labels)
|
||||
print("Events:")
|
||||
for event in self.listener.events:
|
||||
print(event)
|
||||
print("Log:")
|
||||
for log in self.logs:
|
||||
print(log)
|
||||
raise
|
||||
|
||||
POOL_OPTIONS = {
|
||||
"maxPoolSize": 50,
|
||||
"minPoolSize": 1,
|
||||
"maxIdleTimeMS": 10000,
|
||||
"waitQueueTimeoutMS": 10000,
|
||||
}
|
||||
|
||||
#
|
||||
# Prose tests. Numbers correspond to the prose test number in the spec.
|
||||
#
|
||||
async def test_1_client_connection_pool_options(self):
|
||||
client = await self.async_rs_or_single_client(**self.POOL_OPTIONS)
|
||||
pool_opts = (await async_get_pool(client)).opts
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
|
||||
async def test_2_all_client_pools_have_same_options(self):
|
||||
client = await self.async_rs_or_single_client(**self.POOL_OPTIONS)
|
||||
await client.admin.command("ping")
|
||||
# Discover at least one secondary.
|
||||
if await async_client_context.has_secondaries:
|
||||
await client.admin.command("ping", read_preference=ReadPreference.SECONDARY)
|
||||
pools = await async_get_pools(client)
|
||||
pool_opts = pools[0].opts
|
||||
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
for pool in pools[1:]:
|
||||
self.assertEqual(pool.opts, pool_opts)
|
||||
|
||||
async def test_3_uri_connection_pool_options(self):
|
||||
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
|
||||
uri = f"mongodb://{await async_client_context.pair}/?{opts}"
|
||||
client = await self.async_rs_or_single_client(uri)
|
||||
pool_opts = (await async_get_pool(client)).opts
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
|
||||
async def test_4_subscribe_to_events(self):
|
||||
listener = CMAPListener()
|
||||
client = await self.async_single_client(event_listeners=[listener])
|
||||
self.assertEqual(listener.event_count(PoolCreatedEvent), 1)
|
||||
|
||||
# Creates a new connection.
|
||||
await client.admin.command("ping")
|
||||
self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1)
|
||||
self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1)
|
||||
self.assertEqual(listener.event_count(ConnectionReadyEvent), 1)
|
||||
self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1)
|
||||
self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1)
|
||||
|
||||
# Uses the existing connection.
|
||||
await client.admin.command("ping")
|
||||
self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2)
|
||||
self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2)
|
||||
self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2)
|
||||
|
||||
await client.close()
|
||||
self.assertEqual(listener.event_count(PoolClosedEvent), 1)
|
||||
self.assertEqual(listener.event_count(ConnectionClosedEvent), 1)
|
||||
|
||||
async def test_5_check_out_fails_connection_error(self):
|
||||
listener = CMAPListener()
|
||||
client = await self.async_single_client(event_listeners=[listener])
|
||||
pool = await async_get_pool(client)
|
||||
|
||||
def mock_connect(*args, **kwargs):
|
||||
raise ConnectionFailure("connect failed")
|
||||
|
||||
pool.connect = mock_connect
|
||||
# Un-patch Pool.connect to break the cyclic reference.
|
||||
self.addCleanup(delattr, pool, "connect")
|
||||
|
||||
# Attempt to create a new connection.
|
||||
with self.assertRaisesRegex(ConnectionFailure, "connect failed"):
|
||||
await client.admin.command("ping")
|
||||
|
||||
self.assertIsInstance(listener.events[0], PoolCreatedEvent)
|
||||
self.assertIsInstance(listener.events[1], PoolReadyEvent)
|
||||
self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent)
|
||||
self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent)
|
||||
self.assertIsInstance(listener.events[4], PoolClearedEvent)
|
||||
|
||||
failed_event = listener.events[3]
|
||||
self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR)
|
||||
|
||||
@async_client_context.require_no_fips
|
||||
async def test_5_check_out_fails_auth_error(self):
|
||||
listener = CMAPListener()
|
||||
client = await self.async_single_client_noauth(
|
||||
username="notauser", password="fail", event_listeners=[listener]
|
||||
)
|
||||
|
||||
# Attempt to create a new connection.
|
||||
with self.assertRaisesRegex(OperationFailure, "failed"):
|
||||
await client.admin.command("ping")
|
||||
|
||||
self.assertIsInstance(listener.events[0], PoolCreatedEvent)
|
||||
self.assertIsInstance(listener.events[1], PoolReadyEvent)
|
||||
self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent)
|
||||
self.assertIsInstance(listener.events[3], ConnectionCreatedEvent)
|
||||
# Error happens here.
|
||||
self.assertIsInstance(listener.events[4], ConnectionClosedEvent)
|
||||
self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent)
|
||||
self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR)
|
||||
|
||||
#
|
||||
# Extra non-spec tests
|
||||
#
|
||||
def assertRepr(self, obj):
|
||||
new_obj = eval(repr(obj))
|
||||
self.assertEqual(type(new_obj), type(obj))
|
||||
self.assertEqual(repr(new_obj), repr(obj))
|
||||
|
||||
async def test_events_repr(self):
|
||||
host = ("localhost", 27017)
|
||||
self.assertRepr(ConnectionCheckedInEvent(host, 1))
|
||||
self.assertRepr(ConnectionCheckedOutEvent(host, 1, time.monotonic()))
|
||||
self.assertRepr(
|
||||
ConnectionCheckOutFailedEvent(
|
||||
host, ConnectionCheckOutFailedReason.POOL_CLOSED, time.monotonic()
|
||||
)
|
||||
)
|
||||
self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED))
|
||||
self.assertRepr(ConnectionCreatedEvent(host, 1))
|
||||
self.assertRepr(ConnectionReadyEvent(host, 1, time.monotonic()))
|
||||
self.assertRepr(ConnectionCheckOutStartedEvent(host))
|
||||
self.assertRepr(PoolCreatedEvent(host, {}))
|
||||
self.assertRepr(PoolClearedEvent(host))
|
||||
self.assertRepr(PoolClearedEvent(host, service_id=ObjectId()))
|
||||
self.assertRepr(PoolClosedEvent(host))
|
||||
|
||||
async def test_close_leaves_pool_unpaused(self):
|
||||
listener = CMAPListener()
|
||||
client = await self.async_single_client(event_listeners=[listener])
|
||||
await client.admin.command("ping")
|
||||
pool = await async_get_pool(client)
|
||||
await client.close()
|
||||
self.assertEqual(1, listener.event_count(PoolClosedEvent))
|
||||
self.assertEqual(PoolState.CLOSED, pool.state)
|
||||
# Checking out a connection should fail
|
||||
with self.assertRaises(_PoolClosedError):
|
||||
async with pool.checkout():
|
||||
pass
|
||||
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
async def run_scenario(self):
|
||||
await self.run_scenario(scenario_def, test)
|
||||
|
||||
return run_scenario
|
||||
|
||||
|
||||
class CMAPSpecTestCreator(AsyncSpecTestCreator):
|
||||
def tests(self, scenario_def):
|
||||
"""Extract the tests from a spec file.
|
||||
|
||||
CMAP tests do not have a 'tests' field. The whole file represents
|
||||
a single test case.
|
||||
"""
|
||||
return [scenario_def]
|
||||
|
||||
|
||||
test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH)
|
||||
test_creator.create_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -739,7 +739,7 @@ class AsyncTestSpec(AsyncSpecRunner):
|
||||
return errors
|
||||
|
||||
|
||||
async def create_test(scenario_def, test, name):
|
||||
def create_test(scenario_def, test, name):
|
||||
@async_client_context.require_test_commands
|
||||
async def run_scenario(self):
|
||||
await self.run_scenario(scenario_def, test)
|
||||
|
||||
602
test/asynchronous/test_gridfs.py
Normal file
602
test/asynchronous/test_gridfs.py
Normal file
@ -0,0 +1,602 @@
|
||||
#
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the gridfs package."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from io import BytesIO
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import async_joinall, one
|
||||
|
||||
import gridfs
|
||||
from bson.binary import Binary
|
||||
from gridfs.asynchronous.grid_file import DEFAULT_CHUNK_SIZE, AsyncGridOutCursor
|
||||
from gridfs.errors import CorruptGridFile, FileExists, NoFile
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
NotPrimaryError,
|
||||
ServerSelectionTimeoutError,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class JustWrite(ConcurrentRunner):
|
||||
def __init__(self, fs, n):
|
||||
super().__init__()
|
||||
self.fs = fs
|
||||
self.n = n
|
||||
self.daemon = True
|
||||
|
||||
async def run(self):
|
||||
for _ in range(self.n):
|
||||
file = self.fs.new_file(filename="test")
|
||||
await file.write(b"hello")
|
||||
await file.close()
|
||||
|
||||
|
||||
class JustRead(ConcurrentRunner):
|
||||
def __init__(self, fs, n, results):
|
||||
super().__init__()
|
||||
self.fs = fs
|
||||
self.n = n
|
||||
self.results = results
|
||||
self.daemon = True
|
||||
|
||||
async def run(self):
|
||||
for _ in range(self.n):
|
||||
file = await self.fs.get("test")
|
||||
data = await file.read()
|
||||
self.results.append(data)
|
||||
assert data == b"hello"
|
||||
|
||||
|
||||
class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase):
|
||||
db: AsyncDatabase
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.db = AsyncMongoClient(connect=False).pymongo_test
|
||||
|
||||
async def test_gridfs(self):
|
||||
self.assertRaises(TypeError, gridfs.AsyncGridFS, "foo")
|
||||
self.assertRaises(TypeError, gridfs.AsyncGridFS, self.db, 5)
|
||||
|
||||
|
||||
class TestGridfs(AsyncIntegrationTest):
|
||||
fs: gridfs.AsyncGridFS
|
||||
alt: gridfs.AsyncGridFS
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.fs = gridfs.AsyncGridFS(self.db)
|
||||
self.alt = gridfs.AsyncGridFS(self.db, "alt")
|
||||
await self.cleanup_colls(
|
||||
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
||||
)
|
||||
|
||||
async def test_basic(self):
|
||||
oid = await self.fs.put(b"hello world")
|
||||
self.assertEqual(b"hello world", await (await self.fs.get(oid)).read())
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
await self.fs.delete(oid)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get(oid)
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get("foo")
|
||||
oid = await self.fs.put(b"hello world", _id="foo")
|
||||
self.assertEqual("foo", oid)
|
||||
self.assertEqual(b"hello world", await (await self.fs.get("foo")).read())
|
||||
|
||||
async def test_multi_chunk_delete(self):
|
||||
await self.db.fs.drop()
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
gfs = gridfs.AsyncGridFS(self.db)
|
||||
oid = await gfs.put(b"hello", chunkSize=1)
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(5, await self.db.fs.chunks.count_documents({}))
|
||||
await gfs.delete(oid)
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
async def test_list(self):
|
||||
self.assertEqual([], await self.fs.list())
|
||||
await self.fs.put(b"hello world")
|
||||
self.assertEqual([], await self.fs.list())
|
||||
|
||||
# PYTHON-598: in server versions before 2.5.x, creating an index on
|
||||
# filename, uploadDate causes list() to include None.
|
||||
await self.fs.get_last_version()
|
||||
self.assertEqual([], await self.fs.list())
|
||||
|
||||
await self.fs.put(b"", filename="mike")
|
||||
await self.fs.put(b"foo", filename="test")
|
||||
await self.fs.put(b"", filename="hello world")
|
||||
|
||||
self.assertEqual({"mike", "test", "hello world"}, set(await self.fs.list()))
|
||||
|
||||
async def test_empty_file(self):
|
||||
oid = await self.fs.put(b"")
|
||||
self.assertEqual(b"", await (await self.fs.get(oid)).read())
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
raw = await self.db.fs.files.find_one()
|
||||
assert raw is not None
|
||||
self.assertEqual(0, raw["length"])
|
||||
self.assertEqual(oid, raw["_id"])
|
||||
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
|
||||
self.assertEqual(255 * 1024, raw["chunkSize"])
|
||||
self.assertNotIn("md5", raw)
|
||||
|
||||
async def test_corrupt_chunk(self):
|
||||
files_id = await self.fs.put(b"foobar")
|
||||
await self.db.fs.chunks.update_one(
|
||||
{"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}
|
||||
)
|
||||
try:
|
||||
out = await self.fs.get(files_id)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
await out.read()
|
||||
|
||||
out = await self.fs.get(files_id)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
await out.readline()
|
||||
finally:
|
||||
await self.fs.delete(files_id)
|
||||
|
||||
async def test_put_ensures_index(self):
|
||||
chunks = self.db.fs.chunks
|
||||
files = self.db.fs.files
|
||||
# Ensure the collections are removed.
|
||||
await chunks.drop()
|
||||
await files.drop()
|
||||
await self.fs.put(b"junk")
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("files_id", 1), ("n", 1)]
|
||||
for info in (await chunks.index_information()).values()
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("filename", 1), ("uploadDate", 1)]
|
||||
for info in (await files.index_information()).values()
|
||||
)
|
||||
)
|
||||
|
||||
async def test_alt_collection(self):
|
||||
oid = await self.alt.put(b"hello world")
|
||||
self.assertEqual(b"hello world", await (await self.alt.get(oid)).read())
|
||||
self.assertEqual(1, await self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.alt.chunks.count_documents({}))
|
||||
|
||||
await self.alt.delete(oid)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.alt.get(oid)
|
||||
self.assertEqual(0, await self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.alt.chunks.count_documents({}))
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.alt.get("foo")
|
||||
oid = await self.alt.put(b"hello world", _id="foo")
|
||||
self.assertEqual("foo", oid)
|
||||
self.assertEqual(b"hello world", await (await self.alt.get("foo")).read())
|
||||
|
||||
await self.alt.put(b"", filename="mike")
|
||||
await self.alt.put(b"foo", filename="test")
|
||||
await self.alt.put(b"", filename="hello world")
|
||||
|
||||
self.assertEqual({"mike", "test", "hello world"}, set(await self.alt.list()))
|
||||
|
||||
async def test_threaded_reads(self):
|
||||
await self.fs.put(b"hello", _id="test")
|
||||
|
||||
tasks = []
|
||||
results: list = []
|
||||
for i in range(10):
|
||||
tasks.append(JustRead(self.fs, 10, results))
|
||||
await tasks[i].start()
|
||||
|
||||
await async_joinall(tasks)
|
||||
|
||||
self.assertEqual(100 * [b"hello"], results)
|
||||
|
||||
async def test_threaded_writes(self):
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
tasks.append(JustWrite(self.fs, 10))
|
||||
await tasks[i].start()
|
||||
|
||||
await async_joinall(tasks)
|
||||
|
||||
f = await self.fs.get_last_version("test")
|
||||
self.assertEqual(await f.read(), b"hello")
|
||||
|
||||
# Should have created 100 versions of 'test' file
|
||||
self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"}))
|
||||
|
||||
async def test_get_last_version(self):
|
||||
one = await self.fs.put(b"foo", filename="test")
|
||||
await asyncio.sleep(0.01)
|
||||
two = self.fs.new_file(filename="test")
|
||||
await two.write(b"bar")
|
||||
await two.close()
|
||||
await asyncio.sleep(0.01)
|
||||
two = two._id
|
||||
three = await self.fs.put(b"baz", filename="test")
|
||||
|
||||
self.assertEqual(b"baz", await (await self.fs.get_last_version("test")).read())
|
||||
await self.fs.delete(three)
|
||||
self.assertEqual(b"bar", await (await self.fs.get_last_version("test")).read())
|
||||
await self.fs.delete(two)
|
||||
self.assertEqual(b"foo", await (await self.fs.get_last_version("test")).read())
|
||||
await self.fs.delete(one)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_last_version("test")
|
||||
|
||||
async def test_get_last_version_with_metadata(self):
|
||||
one = await self.fs.put(b"foo", filename="test", author="author")
|
||||
await asyncio.sleep(0.01)
|
||||
two = await self.fs.put(b"bar", filename="test", author="author")
|
||||
|
||||
self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author")).read())
|
||||
await self.fs.delete(two)
|
||||
self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author")).read())
|
||||
await self.fs.delete(one)
|
||||
|
||||
one = await self.fs.put(b"foo", filename="test", author="author1")
|
||||
await asyncio.sleep(0.01)
|
||||
two = await self.fs.put(b"bar", filename="test", author="author2")
|
||||
|
||||
self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author1")).read())
|
||||
self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author2")).read())
|
||||
self.assertEqual(b"bar", await (await self.fs.get_last_version(filename="test")).read())
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_last_version(author="author3")
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_last_version(filename="nottest", author="author1")
|
||||
|
||||
await self.fs.delete(one)
|
||||
await self.fs.delete(two)
|
||||
|
||||
async def test_get_version(self):
|
||||
await self.fs.put(b"foo", filename="test")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.put(b"bar", filename="test")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.put(b"baz", filename="test")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
self.assertEqual(b"foo", await (await self.fs.get_version("test", 0)).read())
|
||||
self.assertEqual(b"bar", await (await self.fs.get_version("test", 1)).read())
|
||||
self.assertEqual(b"baz", await (await self.fs.get_version("test", 2)).read())
|
||||
|
||||
self.assertEqual(b"baz", await (await self.fs.get_version("test", -1)).read())
|
||||
self.assertEqual(b"bar", await (await self.fs.get_version("test", -2)).read())
|
||||
self.assertEqual(b"foo", await (await self.fs.get_version("test", -3)).read())
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_version("test", 3)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_version("test", -4)
|
||||
|
||||
async def test_get_version_with_metadata(self):
|
||||
one = await self.fs.put(b"foo", filename="test", author="author1")
|
||||
await asyncio.sleep(0.01)
|
||||
two = await self.fs.put(b"bar", filename="test", author="author1")
|
||||
await asyncio.sleep(0.01)
|
||||
three = await self.fs.put(b"baz", filename="test", author="author2")
|
||||
|
||||
self.assertEqual(
|
||||
b"foo",
|
||||
await (await self.fs.get_version(filename="test", author="author1", version=-2)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"bar",
|
||||
await (await self.fs.get_version(filename="test", author="author1", version=-1)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"foo",
|
||||
await (await self.fs.get_version(filename="test", author="author1", version=0)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"bar",
|
||||
await (await self.fs.get_version(filename="test", author="author1", version=1)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"baz",
|
||||
await (await self.fs.get_version(filename="test", author="author2", version=0)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"baz", await (await self.fs.get_version(filename="test", version=-1)).read()
|
||||
)
|
||||
self.assertEqual(
|
||||
b"baz", await (await self.fs.get_version(filename="test", version=2)).read()
|
||||
)
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_version(filename="test", author="author3")
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.get_version(filename="test", author="author1", version=2)
|
||||
|
||||
await self.fs.delete(one)
|
||||
await self.fs.delete(two)
|
||||
await self.fs.delete(three)
|
||||
|
||||
async def test_put_filelike(self):
|
||||
oid = await self.fs.put(BytesIO(b"hello world"), chunk_size=1)
|
||||
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
|
||||
self.assertEqual(b"hello world", await (await self.fs.get(oid)).read())
|
||||
|
||||
async def test_file_exists(self):
|
||||
oid = await self.fs.put(b"hello")
|
||||
with self.assertRaises(FileExists):
|
||||
await self.fs.put(b"world", _id=oid)
|
||||
|
||||
one = self.fs.new_file(_id=123)
|
||||
await one.write(b"some content")
|
||||
await one.close()
|
||||
|
||||
# Attempt to upload a file with more chunks to the same _id.
|
||||
with patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE):
|
||||
two = self.fs.new_file(_id=123)
|
||||
with self.assertRaises(FileExists):
|
||||
await two.write(b"x" * DEFAULT_CHUNK_SIZE * 3)
|
||||
# Original file is still readable (no extra chunks were uploaded).
|
||||
self.assertEqual(await (await self.fs.get(123)).read(), b"some content")
|
||||
|
||||
two = self.fs.new_file(_id=123)
|
||||
await two.write(b"some content")
|
||||
with self.assertRaises(FileExists):
|
||||
await two.close()
|
||||
# Original file is still readable.
|
||||
self.assertEqual(await (await self.fs.get(123)).read(), b"some content")
|
||||
|
||||
async def test_exists(self):
|
||||
oid = await self.fs.put(b"hello")
|
||||
self.assertTrue(await self.fs.exists(oid))
|
||||
self.assertTrue(await self.fs.exists({"_id": oid}))
|
||||
self.assertTrue(await self.fs.exists(_id=oid))
|
||||
|
||||
self.assertFalse(await self.fs.exists(filename="mike"))
|
||||
self.assertFalse(await self.fs.exists("mike"))
|
||||
|
||||
oid = await self.fs.put(b"hello", filename="mike", foo=12)
|
||||
self.assertTrue(await self.fs.exists(oid))
|
||||
self.assertTrue(await self.fs.exists({"_id": oid}))
|
||||
self.assertTrue(await self.fs.exists(_id=oid))
|
||||
self.assertTrue(await self.fs.exists(filename="mike"))
|
||||
self.assertTrue(await self.fs.exists({"filename": "mike"}))
|
||||
self.assertTrue(await self.fs.exists(foo=12))
|
||||
self.assertTrue(await self.fs.exists({"foo": 12}))
|
||||
self.assertTrue(await self.fs.exists(foo={"$gt": 11}))
|
||||
self.assertTrue(await self.fs.exists({"foo": {"$gt": 11}}))
|
||||
|
||||
self.assertFalse(await self.fs.exists(foo=13))
|
||||
self.assertFalse(await self.fs.exists({"foo": 13}))
|
||||
self.assertFalse(await self.fs.exists(foo={"$gt": 12}))
|
||||
self.assertFalse(await self.fs.exists({"foo": {"$gt": 12}}))
|
||||
|
||||
async def test_put_unicode(self):
|
||||
with self.assertRaises(TypeError):
|
||||
await self.fs.put("hello")
|
||||
|
||||
oid = await self.fs.put("hello", encoding="utf-8")
|
||||
self.assertEqual(b"hello", await (await self.fs.get(oid)).read())
|
||||
self.assertEqual("utf-8", (await self.fs.get(oid)).encoding)
|
||||
|
||||
oid = await self.fs.put("aé", encoding="iso-8859-1")
|
||||
self.assertEqual("aé".encode("iso-8859-1"), await (await self.fs.get(oid)).read())
|
||||
self.assertEqual("iso-8859-1", (await self.fs.get(oid)).encoding)
|
||||
|
||||
async def test_missing_length_iter(self):
|
||||
# Test fix that guards against PHP-237
|
||||
await self.fs.put(b"", filename="empty")
|
||||
doc = await self.db.fs.files.find_one({"filename": "empty"})
|
||||
assert doc is not None
|
||||
doc.pop("length")
|
||||
await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
|
||||
f = await self.fs.get_last_version(filename="empty")
|
||||
|
||||
async def iterate_file(grid_file):
|
||||
async for _chunk in grid_file:
|
||||
pass
|
||||
return True
|
||||
|
||||
self.assertTrue(await iterate_file(f))
|
||||
|
||||
async def test_gridfs_lazy_connect(self):
|
||||
client = await self.async_single_client(
|
||||
"badhost", connect=False, serverSelectionTimeoutMS=10
|
||||
)
|
||||
db = client.db
|
||||
gfs = gridfs.AsyncGridFS(db)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await gfs.list()
|
||||
|
||||
fs = gridfs.AsyncGridFS(db)
|
||||
f = fs.new_file()
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await f.close()
|
||||
|
||||
async def test_gridfs_find(self):
|
||||
await self.fs.put(b"test2", filename="two")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.put(b"test2+", filename="two")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.put(b"test1", filename="one")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.put(b"test2++", filename="two")
|
||||
files = self.db.fs.files
|
||||
self.assertEqual(3, await files.count_documents({"filename": "two"}))
|
||||
self.assertEqual(4, await files.count_documents({}))
|
||||
cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2)
|
||||
gout = await cursor.next()
|
||||
self.assertEqual(b"test1", await gout.read())
|
||||
await cursor.rewind()
|
||||
gout = await cursor.next()
|
||||
self.assertEqual(b"test1", await gout.read())
|
||||
gout = await cursor.next()
|
||||
self.assertEqual(b"test2+", await gout.read())
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await cursor.__anext__()
|
||||
await cursor.rewind()
|
||||
items = await cursor.to_list()
|
||||
self.assertEqual(len(items), 2)
|
||||
await cursor.rewind()
|
||||
items = await cursor.to_list(1)
|
||||
self.assertEqual(len(items), 1)
|
||||
await cursor.close()
|
||||
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
|
||||
|
||||
async def test_delete_not_initialized(self):
|
||||
# Creating a cursor with invalid arguments will not run __init__
|
||||
# but will still call __del__.
|
||||
cursor = AsyncGridOutCursor.__new__(AsyncGridOutCursor) # Skip calling __init__
|
||||
with self.assertRaises(TypeError):
|
||||
cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore
|
||||
cursor.__del__() # no error
|
||||
|
||||
async def test_gridfs_find_one(self):
|
||||
self.assertEqual(None, await self.fs.find_one())
|
||||
|
||||
id1 = await self.fs.put(b"test1", filename="file1")
|
||||
res = await self.fs.find_one()
|
||||
assert res is not None
|
||||
self.assertEqual(b"test1", await res.read())
|
||||
|
||||
id2 = await self.fs.put(b"test2", filename="file2", meta="data")
|
||||
res1 = await self.fs.find_one(id1)
|
||||
assert res1 is not None
|
||||
self.assertEqual(b"test1", await res1.read())
|
||||
res2 = await self.fs.find_one(id2)
|
||||
assert res2 is not None
|
||||
self.assertEqual(b"test2", await res2.read())
|
||||
|
||||
res3 = await self.fs.find_one({"filename": "file1"})
|
||||
assert res3 is not None
|
||||
self.assertEqual(b"test1", await res3.read())
|
||||
|
||||
res4 = await self.fs.find_one(id2)
|
||||
assert res4 is not None
|
||||
self.assertEqual("data", res4.meta)
|
||||
|
||||
async def test_grid_in_non_int_chunksize(self):
|
||||
# Lua, and perhaps other buggy AsyncGridFS clients, store size as a float.
|
||||
data = b"data"
|
||||
await self.fs.put(data, filename="f")
|
||||
await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
|
||||
|
||||
self.assertEqual(data, await (await self.fs.get_version("f")).read())
|
||||
|
||||
async def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
gridfs.AsyncGridFS((await self.async_rs_or_single_client(w=0)).pymongo_test)
|
||||
|
||||
async def test_md5(self):
|
||||
gin = self.fs.new_file()
|
||||
await gin.write(b"no md5 sum")
|
||||
await gin.close()
|
||||
self.assertIsNone(gin.md5)
|
||||
|
||||
gout = await self.fs.get(gin._id)
|
||||
self.assertIsNone(gout.md5)
|
||||
|
||||
_id = await self.fs.put(b"still no md5 sum")
|
||||
gout = await self.fs.get(_id)
|
||||
self.assertIsNone(gout.md5)
|
||||
|
||||
|
||||
class TestGridfsReplicaSet(AsyncIntegrationTest):
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def asyncTearDownClass(cls):
|
||||
await async_client_context.client.drop_database("gfsreplica")
|
||||
|
||||
async def test_gridfs_replica_set(self):
|
||||
rsc = await self.async_rs_client(
|
||||
w=async_client_context.w, read_preference=ReadPreference.SECONDARY
|
||||
)
|
||||
|
||||
fs = gridfs.AsyncGridFS(rsc.gfsreplica, "gfsreplicatest")
|
||||
|
||||
gin = fs.new_file()
|
||||
self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY)
|
||||
|
||||
oid = await fs.put(b"foo")
|
||||
content = await (await fs.get(oid)).read()
|
||||
self.assertEqual(b"foo", content)
|
||||
|
||||
async def test_gridfs_secondary(self):
|
||||
secondary_host, secondary_port = one(await self.client.secondaries)
|
||||
secondary_connection = await self.async_single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
|
||||
)
|
||||
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
# create index
|
||||
fs = gridfs.AsyncGridFS(secondary_connection.gfsreplica, "gfssecondarytest")
|
||||
|
||||
# This won't detect secondary, raises error
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
await fs.put(b"foo")
|
||||
|
||||
async def test_gridfs_secondary_lazy(self):
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
# create index.
|
||||
secondary_host, secondary_port = one(await self.client.secondaries)
|
||||
client = await self.async_single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
|
||||
)
|
||||
|
||||
# Still no connection.
|
||||
fs = gridfs.AsyncGridFS(client.gfsreplica, "gfssecondarylazytest")
|
||||
|
||||
# Connects, doesn't create index.
|
||||
with self.assertRaises(NoFile):
|
||||
await fs.get_last_version()
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
await fs.put("data", encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
574
test/asynchronous/test_gridfs_bucket.py
Normal file
574
test/asynchronous/test_gridfs_bucket.py
Normal file
@ -0,0 +1,574 @@
|
||||
#
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the gridfs package."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import itertools
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from io import BytesIO
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import async_joinall, joinall, one
|
||||
|
||||
import gridfs
|
||||
from bson.binary import Binary
|
||||
from bson.int64 import Int64
|
||||
from bson.objectid import ObjectId
|
||||
from bson.son import SON
|
||||
from gridfs.errors import CorruptGridFile, NoFile
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
NotPrimaryError,
|
||||
ServerSelectionTimeoutError,
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class JustWrite(ConcurrentRunner):
|
||||
def __init__(self, gfs, num):
|
||||
super().__init__()
|
||||
self.gfs = gfs
|
||||
self.num = num
|
||||
self.daemon = True
|
||||
|
||||
async def run(self):
|
||||
for _ in range(self.num):
|
||||
file = self.gfs.open_upload_stream("test")
|
||||
await file.write(b"hello")
|
||||
await file.close()
|
||||
|
||||
|
||||
class JustRead(ConcurrentRunner):
|
||||
def __init__(self, gfs, num, results):
|
||||
super().__init__()
|
||||
self.gfs = gfs
|
||||
self.num = num
|
||||
self.results = results
|
||||
self.daemon = True
|
||||
|
||||
async def run(self):
|
||||
for _ in range(self.num):
|
||||
file = await self.gfs.open_download_stream_by_name("test")
|
||||
data = await file.read()
|
||||
self.results.append(data)
|
||||
assert data == b"hello"
|
||||
|
||||
|
||||
class TestGridfs(AsyncIntegrationTest):
|
||||
fs: gridfs.AsyncGridFSBucket
|
||||
alt: gridfs.AsyncGridFSBucket
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.fs = gridfs.AsyncGridFSBucket(self.db)
|
||||
self.alt = gridfs.AsyncGridFSBucket(self.db, bucket_name="alt")
|
||||
await self.cleanup_colls(
|
||||
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
||||
)
|
||||
|
||||
async def test_basic(self):
|
||||
oid = await self.fs.upload_from_stream("test_filename", b"hello world")
|
||||
self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read())
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
await self.fs.delete(oid)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.open_download_stream(oid)
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
async def test_multi_chunk_delete(self):
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
gfs = gridfs.AsyncGridFSBucket(self.db)
|
||||
oid = await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1)
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(5, await self.db.fs.chunks.count_documents({}))
|
||||
await gfs.delete(oid)
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
async def test_empty_file(self):
|
||||
oid = await self.fs.upload_from_stream("test_filename", b"")
|
||||
self.assertEqual(b"", await (await self.fs.open_download_stream(oid)).read())
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
raw = await self.db.fs.files.find_one()
|
||||
assert raw is not None
|
||||
self.assertEqual(0, raw["length"])
|
||||
self.assertEqual(oid, raw["_id"])
|
||||
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
|
||||
self.assertEqual(255 * 1024, raw["chunkSize"])
|
||||
self.assertNotIn("md5", raw)
|
||||
|
||||
async def test_corrupt_chunk(self):
|
||||
files_id = await self.fs.upload_from_stream("test_filename", b"foobar")
|
||||
await self.db.fs.chunks.update_one(
|
||||
{"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}
|
||||
)
|
||||
try:
|
||||
out = await self.fs.open_download_stream(files_id)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
await out.read()
|
||||
|
||||
out = await self.fs.open_download_stream(files_id)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
await out.readline()
|
||||
finally:
|
||||
await self.fs.delete(files_id)
|
||||
|
||||
async def test_upload_ensures_index(self):
|
||||
chunks = self.db.fs.chunks
|
||||
files = self.db.fs.files
|
||||
# Ensure the collections are removed.
|
||||
await chunks.drop()
|
||||
await files.drop()
|
||||
await self.fs.upload_from_stream("filename", b"junk")
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("files_id", 1), ("n", 1)]
|
||||
for info in (await chunks.index_information()).values()
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("filename", 1), ("uploadDate", 1)]
|
||||
for info in (await files.index_information()).values()
|
||||
)
|
||||
)
|
||||
|
||||
async def test_ensure_index_shell_compat(self):
|
||||
files = self.db.fs.files
|
||||
for i, j in itertools.combinations_with_replacement([1, 1.0, Int64(1)], 2):
|
||||
# Create the index with different numeric types (as might be done
|
||||
# from the mongo shell).
|
||||
shell_index = [("filename", i), ("uploadDate", j)]
|
||||
await self.db.command(
|
||||
"createIndexes",
|
||||
files.name,
|
||||
indexes=[{"key": SON(shell_index), "name": "filename_1.0_uploadDate_1.0"}],
|
||||
)
|
||||
|
||||
# No error.
|
||||
await self.fs.upload_from_stream("filename", b"data")
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("filename", 1), ("uploadDate", 1)]
|
||||
for info in (await files.index_information()).values()
|
||||
)
|
||||
)
|
||||
await files.drop()
|
||||
|
||||
async def test_alt_collection(self):
|
||||
oid = await self.alt.upload_from_stream("test_filename", b"hello world")
|
||||
self.assertEqual(b"hello world", await (await self.alt.open_download_stream(oid)).read())
|
||||
self.assertEqual(1, await self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.alt.chunks.count_documents({}))
|
||||
|
||||
await self.alt.delete(oid)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.alt.open_download_stream(oid)
|
||||
self.assertEqual(0, await self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.alt.chunks.count_documents({}))
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.alt.open_download_stream("foo")
|
||||
await self.alt.upload_from_stream("foo", b"hello world")
|
||||
self.assertEqual(
|
||||
b"hello world", await (await self.alt.open_download_stream_by_name("foo")).read()
|
||||
)
|
||||
|
||||
await self.alt.upload_from_stream("mike", b"")
|
||||
await self.alt.upload_from_stream("test", b"foo")
|
||||
await self.alt.upload_from_stream("hello world", b"")
|
||||
|
||||
self.assertEqual(
|
||||
{"mike", "test", "hello world", "foo"},
|
||||
{k["filename"] for k in await self.db.alt.files.find().to_list()},
|
||||
)
|
||||
|
||||
async def test_threaded_reads(self):
|
||||
await self.fs.upload_from_stream("test", b"hello")
|
||||
|
||||
threads = []
|
||||
results: list = []
|
||||
for i in range(10):
|
||||
threads.append(JustRead(self.fs, 10, results))
|
||||
await threads[i].start()
|
||||
|
||||
await async_joinall(threads)
|
||||
|
||||
self.assertEqual(100 * [b"hello"], results)
|
||||
|
||||
async def test_threaded_writes(self):
|
||||
threads = []
|
||||
for i in range(10):
|
||||
threads.append(JustWrite(self.fs, 10))
|
||||
await threads[i].start()
|
||||
|
||||
await async_joinall(threads)
|
||||
|
||||
fstr = await self.fs.open_download_stream_by_name("test")
|
||||
self.assertEqual(await fstr.read(), b"hello")
|
||||
|
||||
# Should have created 100 versions of 'test' file
|
||||
self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"}))
|
||||
|
||||
async def test_get_last_version(self):
|
||||
one = await self.fs.upload_from_stream("test", b"foo")
|
||||
await asyncio.sleep(0.01)
|
||||
two = self.fs.open_upload_stream("test")
|
||||
await two.write(b"bar")
|
||||
await two.close()
|
||||
await asyncio.sleep(0.01)
|
||||
two = two._id
|
||||
three = await self.fs.upload_from_stream("test", b"baz")
|
||||
|
||||
self.assertEqual(b"baz", await (await self.fs.open_download_stream_by_name("test")).read())
|
||||
await self.fs.delete(three)
|
||||
self.assertEqual(b"bar", await (await self.fs.open_download_stream_by_name("test")).read())
|
||||
await self.fs.delete(two)
|
||||
self.assertEqual(b"foo", await (await self.fs.open_download_stream_by_name("test")).read())
|
||||
await self.fs.delete(one)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.open_download_stream_by_name("test")
|
||||
|
||||
async def test_get_version(self):
|
||||
await self.fs.upload_from_stream("test", b"foo")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.upload_from_stream("test", b"bar")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.upload_from_stream("test", b"baz")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
self.assertEqual(
|
||||
b"foo", await (await self.fs.open_download_stream_by_name("test", revision=0)).read()
|
||||
)
|
||||
self.assertEqual(
|
||||
b"bar", await (await self.fs.open_download_stream_by_name("test", revision=1)).read()
|
||||
)
|
||||
self.assertEqual(
|
||||
b"baz", await (await self.fs.open_download_stream_by_name("test", revision=2)).read()
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
b"baz", await (await self.fs.open_download_stream_by_name("test", revision=-1)).read()
|
||||
)
|
||||
self.assertEqual(
|
||||
b"bar", await (await self.fs.open_download_stream_by_name("test", revision=-2)).read()
|
||||
)
|
||||
self.assertEqual(
|
||||
b"foo", await (await self.fs.open_download_stream_by_name("test", revision=-3)).read()
|
||||
)
|
||||
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.open_download_stream_by_name("test", revision=3)
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.open_download_stream_by_name("test", revision=-4)
|
||||
|
||||
async def test_upload_from_stream(self):
|
||||
oid = await self.fs.upload_from_stream(
|
||||
"test_file", BytesIO(b"hello world"), chunk_size_bytes=1
|
||||
)
|
||||
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
|
||||
self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read())
|
||||
|
||||
async def test_upload_from_stream_with_id(self):
|
||||
oid = ObjectId()
|
||||
await self.fs.upload_from_stream_with_id(
|
||||
oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1
|
||||
)
|
||||
self.assertEqual(b"custom id", await (await self.fs.open_download_stream(oid)).read())
|
||||
|
||||
@patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3)
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_upload_bulk_write_error(self):
|
||||
# Test BulkWriteError from insert_many is converted to an insert_one style error.
|
||||
expected_wce = {
|
||||
"code": 100,
|
||||
"codeName": "UnsatisfiableWriteConcern",
|
||||
"errmsg": "Not enough data-bearing nodes",
|
||||
}
|
||||
cause_wce = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["insert"], "writeConcernError": expected_wce},
|
||||
}
|
||||
gin = self.fs.open_upload_stream("test_file", chunk_size_bytes=1)
|
||||
async with self.fail_point(cause_wce):
|
||||
# Assert we raise WriteConcernError, not BulkWriteError.
|
||||
with self.assertRaises(WriteConcernError):
|
||||
await gin.write(b"hello world")
|
||||
# 3 chunks were uploaded.
|
||||
self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
await gin.abort()
|
||||
|
||||
@patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 10)
|
||||
async def test_upload_batching(self):
|
||||
async with self.fs.open_upload_stream("test_file", chunk_size_bytes=1) as gin:
|
||||
await gin.write(b"s" * (10 - 1))
|
||||
# No chunks were uploaded yet.
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
await gin.write(b"s")
|
||||
# All chunks were uploaded since we hit the _UPLOAD_BUFFER_CHUNKS limit.
|
||||
self.assertEqual(10, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
|
||||
async def test_open_upload_stream(self):
|
||||
gin = self.fs.open_upload_stream("from_stream")
|
||||
await gin.write(b"from stream")
|
||||
await gin.close()
|
||||
self.assertEqual(b"from stream", await (await self.fs.open_download_stream(gin._id)).read())
|
||||
|
||||
async def test_open_upload_stream_with_id(self):
|
||||
oid = ObjectId()
|
||||
gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id")
|
||||
await gin.write(b"from stream with custom id")
|
||||
await gin.close()
|
||||
self.assertEqual(
|
||||
b"from stream with custom id", await (await self.fs.open_download_stream(oid)).read()
|
||||
)
|
||||
|
||||
async def test_missing_length_iter(self):
|
||||
# Test fix that guards against PHP-237
|
||||
await self.fs.upload_from_stream("empty", b"")
|
||||
doc = await self.db.fs.files.find_one({"filename": "empty"})
|
||||
assert doc is not None
|
||||
doc.pop("length")
|
||||
await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
|
||||
fstr = await self.fs.open_download_stream_by_name("empty")
|
||||
|
||||
async def iterate_file(grid_file):
|
||||
async for _ in grid_file:
|
||||
pass
|
||||
return True
|
||||
|
||||
self.assertTrue(await iterate_file(fstr))
|
||||
|
||||
async def test_gridfs_lazy_connect(self):
|
||||
client = await self.async_single_client(
|
||||
"badhost", connect=False, serverSelectionTimeoutMS=0
|
||||
)
|
||||
cdb = client.db
|
||||
gfs = gridfs.AsyncGridFSBucket(cdb)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await gfs.delete(0)
|
||||
|
||||
gfs = gridfs.AsyncGridFSBucket(cdb)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await gfs.upload_from_stream("test", b"") # Still no connection.
|
||||
|
||||
async def test_gridfs_find(self):
|
||||
await self.fs.upload_from_stream("two", b"test2")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.upload_from_stream("two", b"test2+")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.upload_from_stream("one", b"test1")
|
||||
await asyncio.sleep(0.01)
|
||||
await self.fs.upload_from_stream("two", b"test2++")
|
||||
files = self.db.fs.files
|
||||
self.assertEqual(3, await files.count_documents({"filename": "two"}))
|
||||
self.assertEqual(4, await files.count_documents({}))
|
||||
cursor = self.fs.find(
|
||||
{}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2
|
||||
)
|
||||
gout = await cursor.next()
|
||||
self.assertEqual(b"test1", await gout.read())
|
||||
await cursor.rewind()
|
||||
gout = await cursor.next()
|
||||
self.assertEqual(b"test1", await gout.read())
|
||||
gout = await cursor.next()
|
||||
self.assertEqual(b"test2+", await gout.read())
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await cursor.next()
|
||||
await cursor.close()
|
||||
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
|
||||
|
||||
async def test_grid_in_non_int_chunksize(self):
|
||||
# Lua, and perhaps other buggy AsyncGridFS clients, store size as a float.
|
||||
data = b"data"
|
||||
await self.fs.upload_from_stream("f", data)
|
||||
await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
|
||||
|
||||
self.assertEqual(data, await (await self.fs.open_download_stream_by_name("f")).read())
|
||||
|
||||
async def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
gridfs.AsyncGridFSBucket((await self.async_rs_or_single_client(w=0)).pymongo_test)
|
||||
|
||||
async def test_rename(self):
|
||||
_id = await self.fs.upload_from_stream("first_name", b"testing")
|
||||
self.assertEqual(
|
||||
b"testing", await (await self.fs.open_download_stream_by_name("first_name")).read()
|
||||
)
|
||||
|
||||
await self.fs.rename(_id, "second_name")
|
||||
with self.assertRaises(NoFile):
|
||||
await self.fs.open_download_stream_by_name("first_name")
|
||||
self.assertEqual(
|
||||
b"testing", await (await self.fs.open_download_stream_by_name("second_name")).read()
|
||||
)
|
||||
|
||||
@patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", 5)
|
||||
async def test_abort(self):
|
||||
gin = self.fs.open_upload_stream("test_filename", chunk_size_bytes=5)
|
||||
await gin.write(b"test1")
|
||||
await gin.write(b"test2")
|
||||
await gin.write(b"test3")
|
||||
self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
await gin.abort()
|
||||
self.assertTrue(gin.closed)
|
||||
with self.assertRaises(ValueError):
|
||||
await gin.write(b"test4")
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
|
||||
async def test_download_to_stream(self):
|
||||
file1 = BytesIO(b"hello world")
|
||||
# Test with one chunk.
|
||||
oid = await self.fs.upload_from_stream("one_chunk", file1)
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
file2 = BytesIO()
|
||||
await self.fs.download_to_stream(oid, file2)
|
||||
file1.seek(0)
|
||||
file2.seek(0)
|
||||
self.assertEqual(file1.read(), file2.read())
|
||||
|
||||
# Test with many chunks.
|
||||
await self.db.drop_collection("fs.files")
|
||||
await self.db.drop_collection("fs.chunks")
|
||||
file1.seek(0)
|
||||
oid = await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1)
|
||||
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
|
||||
file2 = BytesIO()
|
||||
await self.fs.download_to_stream(oid, file2)
|
||||
file1.seek(0)
|
||||
file2.seek(0)
|
||||
self.assertEqual(file1.read(), file2.read())
|
||||
|
||||
async def test_download_to_stream_by_name(self):
|
||||
file1 = BytesIO(b"hello world")
|
||||
# Test with one chunk.
|
||||
_ = await self.fs.upload_from_stream("one_chunk", file1)
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
file2 = BytesIO()
|
||||
await self.fs.download_to_stream_by_name("one_chunk", file2)
|
||||
file1.seek(0)
|
||||
file2.seek(0)
|
||||
self.assertEqual(file1.read(), file2.read())
|
||||
|
||||
# Test with many chunks.
|
||||
await self.db.drop_collection("fs.files")
|
||||
await self.db.drop_collection("fs.chunks")
|
||||
file1.seek(0)
|
||||
await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1)
|
||||
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
file2 = BytesIO()
|
||||
await self.fs.download_to_stream_by_name("many_chunks", file2)
|
||||
file1.seek(0)
|
||||
file2.seek(0)
|
||||
self.assertEqual(file1.read(), file2.read())
|
||||
|
||||
async def test_md5(self):
|
||||
gin = self.fs.open_upload_stream("no md5")
|
||||
await gin.write(b"no md5 sum")
|
||||
await gin.close()
|
||||
self.assertIsNone(gin.md5)
|
||||
|
||||
gout = await self.fs.open_download_stream(gin._id)
|
||||
self.assertIsNone(gout.md5)
|
||||
|
||||
gin = self.fs.open_upload_stream_with_id(ObjectId(), "also no md5")
|
||||
await gin.write(b"also no md5 sum")
|
||||
await gin.close()
|
||||
self.assertIsNone(gin.md5)
|
||||
|
||||
gout = await self.fs.open_download_stream(gin._id)
|
||||
self.assertIsNone(gout.md5)
|
||||
|
||||
|
||||
class TestGridfsBucketReplicaSet(AsyncIntegrationTest):
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def asyncTearDownClass(cls):
|
||||
await async_client_context.client.drop_database("gfsbucketreplica")
|
||||
|
||||
async def test_gridfs_replica_set(self):
|
||||
rsc = await self.async_rs_client(
|
||||
w=async_client_context.w, read_preference=ReadPreference.SECONDARY
|
||||
)
|
||||
|
||||
gfs = gridfs.AsyncGridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest")
|
||||
oid = await gfs.upload_from_stream("test_filename", b"foo")
|
||||
content = await (await gfs.open_download_stream(oid)).read()
|
||||
self.assertEqual(b"foo", content)
|
||||
|
||||
async def test_gridfs_secondary(self):
|
||||
secondary_host, secondary_port = one(await self.client.secondaries)
|
||||
secondary_connection = await self.async_single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
|
||||
)
|
||||
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
# create index
|
||||
gfs = gridfs.AsyncGridFSBucket(
|
||||
secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest"
|
||||
)
|
||||
|
||||
# This won't detect secondary, raises error
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
await gfs.upload_from_stream("test_filename", b"foo")
|
||||
|
||||
async def test_gridfs_secondary_lazy(self):
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
# create index.
|
||||
secondary_host, secondary_port = one(await self.client.secondaries)
|
||||
client = await self.async_single_client(
|
||||
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
|
||||
)
|
||||
|
||||
# Still no connection.
|
||||
gfs = gridfs.AsyncGridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest")
|
||||
|
||||
# Connects, doesn't create index.
|
||||
with self.assertRaises(NoFile):
|
||||
await gfs.open_download_stream_by_name("test_filename")
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
await gfs.upload_from_stream("test_filename", b"data")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
211
test/asynchronous/test_server_selection.py
Normal file
211
test/asynchronous/test_server_selection.py
Normal file
@ -0,0 +1,211 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the topology module's Server Selection Spec implementation."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pymongo import AsyncMongoClient, ReadPreference
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology
|
||||
from pymongo.errors import ServerSelectionTimeoutError
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.typings import strip_optional
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.asynchronous.utils_selection_tests import (
|
||||
create_selection_tests,
|
||||
get_addresses,
|
||||
get_topology_settings_dict,
|
||||
make_server_description,
|
||||
)
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
FunctionCallRecorder,
|
||||
OvertCommandListener,
|
||||
async_wait_until,
|
||||
)
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent, "server_selection", "server_selection"
|
||||
)
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent, "server_selection", "server_selection"
|
||||
)
|
||||
|
||||
|
||||
class SelectionStoreSelector:
|
||||
"""No-op selector that keeps track of what was passed to it."""
|
||||
|
||||
def __init__(self):
|
||||
self.selection = None
|
||||
|
||||
def __call__(self, selection):
|
||||
self.selection = selection
|
||||
return selection
|
||||
|
||||
|
||||
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class TestCustomServerSelectorFunction(AsyncIntegrationTest):
|
||||
@async_client_context.require_replica_set
|
||||
async def test_functional_select_max_port_number_host(self):
|
||||
# Selector that returns server with highest port number.
|
||||
def custom_selector(servers):
|
||||
ports = [s.address[1] for s in servers]
|
||||
idx = ports.index(max(ports))
|
||||
return [servers[idx]]
|
||||
|
||||
# Initialize client with appropriate listeners.
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
server_selector=custom_selector, event_listeners=[listener]
|
||||
)
|
||||
coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll
|
||||
self.addAsyncCleanup(client.drop_database, "testdb")
|
||||
|
||||
# Wait the node list to be fully populated.
|
||||
async def all_hosts_started():
|
||||
return len((await client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len(
|
||||
client._topology._description.readable_servers
|
||||
)
|
||||
|
||||
await async_wait_until(all_hosts_started, "receive heartbeat from all hosts")
|
||||
|
||||
expected_port = max(
|
||||
[strip_optional(n.address[1]) for n in client._topology._description.readable_servers]
|
||||
)
|
||||
|
||||
# Insert 1 record and access it 10 times.
|
||||
await coll.insert_one({"name": "John Doe"})
|
||||
for _ in range(10):
|
||||
await coll.find_one({"name": "John Doe"})
|
||||
|
||||
# Confirm all find commands are run against appropriate host.
|
||||
for command in listener.started_events:
|
||||
if command.command_name == "find":
|
||||
self.assertEqual(command.connection_id[1], expected_port)
|
||||
|
||||
async def test_invalid_server_selector(self):
|
||||
# Client initialization must fail if server_selector is not callable.
|
||||
for selector_candidate in [[], 10, "string", {}]:
|
||||
with self.assertRaisesRegex(ValueError, "must be a callable"):
|
||||
AsyncMongoClient(connect=False, server_selector=selector_candidate)
|
||||
|
||||
# None value for server_selector is OK.
|
||||
AsyncMongoClient(connect=False, server_selector=None)
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
async def test_selector_called(self):
|
||||
selector = FunctionCallRecorder(lambda x: x)
|
||||
|
||||
# Client setup.
|
||||
mongo_client = await self.async_rs_or_single_client(server_selector=selector)
|
||||
test_collection = mongo_client.testdb.test_collection
|
||||
self.addAsyncCleanup(mongo_client.drop_database, "testdb")
|
||||
|
||||
# Do N operations and test selector is called at least N times.
|
||||
await test_collection.insert_one({"age": 20, "name": "John"})
|
||||
await test_collection.insert_one({"age": 31, "name": "Jane"})
|
||||
await test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}})
|
||||
await test_collection.find_one({"name": "Roe"})
|
||||
self.assertGreaterEqual(selector.call_count, 4)
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
async def test_latency_threshold_application(self):
|
||||
selector = SelectionStoreSelector()
|
||||
|
||||
scenario_def: dict = {
|
||||
"topology_description": {
|
||||
"type": "ReplicaSetWithPrimary",
|
||||
"servers": [
|
||||
{"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}},
|
||||
{"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}},
|
||||
{"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSPrimary", "tag": {}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
# Create & populate Topology such that all but one server is too slow.
|
||||
rtt_times = [srv["avg_rtt_ms"] for srv in scenario_def["topology_description"]["servers"]]
|
||||
min_rtt_idx = rtt_times.index(min(rtt_times))
|
||||
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
|
||||
settings = get_topology_settings_dict(
|
||||
heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector
|
||||
)
|
||||
topology = Topology(TopologySettings(**settings))
|
||||
await topology.open()
|
||||
for server in scenario_def["topology_description"]["servers"]:
|
||||
server_description = make_server_description(server, hosts)
|
||||
await topology.on_change(server_description)
|
||||
|
||||
# Invoke server selection and assert no filtering based on latency
|
||||
# prior to custom server selection logic kicking in.
|
||||
server = await topology.select_server(ReadPreference.NEAREST, _Op.TEST)
|
||||
assert selector.selection is not None
|
||||
self.assertEqual(len(selector.selection), len(topology.description.server_descriptions()))
|
||||
|
||||
# Ensure proper filtering based on latency after custom selection.
|
||||
self.assertEqual(server.description.address, seeds[min_rtt_idx])
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
async def test_server_selector_bypassed(self):
|
||||
selector = FunctionCallRecorder(lambda x: x)
|
||||
|
||||
scenario_def = {
|
||||
"topology_description": {
|
||||
"type": "ReplicaSetNoPrimary",
|
||||
"servers": [
|
||||
{"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}},
|
||||
{"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}},
|
||||
{"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSSecondary", "tag": {}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
# Create & populate Topology such that no server is writeable.
|
||||
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
|
||||
settings = get_topology_settings_dict(
|
||||
heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector
|
||||
)
|
||||
topology = Topology(TopologySettings(**settings))
|
||||
await topology.open()
|
||||
for server in scenario_def["topology_description"]["servers"]:
|
||||
server_description = make_server_description(server, hosts)
|
||||
await topology.on_change(server_description)
|
||||
|
||||
# Invoke server selection and assert no calls to our custom selector.
|
||||
with self.assertRaisesRegex(ServerSelectionTimeoutError, "No primary available for writes"):
|
||||
await topology.select_server(
|
||||
writable_server_selector, _Op.TEST, server_selection_timeout=0.1
|
||||
)
|
||||
self.assertEqual(selector.call_count, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
179
test/asynchronous/test_server_selection_in_window.py
Normal file
179
test/asynchronous/test_server_selection_in_window.py
Normal file
@ -0,0 +1,179 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the topology module's Server Selection Spec implementation."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
from test.asynchronous.utils_selection_tests import create_topology
|
||||
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
async_get_pool,
|
||||
async_wait_until,
|
||||
)
|
||||
|
||||
from pymongo.common import clean_node
|
||||
from pymongo.monitoring import ConnectionReadyEvent
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
_IS_SYNC = False
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
|
||||
)
|
||||
|
||||
|
||||
class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
|
||||
async def run_scenario(self, scenario_def):
|
||||
topology = await create_topology(scenario_def)
|
||||
|
||||
# Update mock operation_count state:
|
||||
for mock in scenario_def["mocked_topology_state"]:
|
||||
address = clean_node(mock["address"])
|
||||
server = topology.get_server_by_address(address)
|
||||
server.pool.operation_count = mock["operation_count"]
|
||||
|
||||
pref = ReadPreference.NEAREST
|
||||
counts = {address: 0 for address in topology._description.server_descriptions()}
|
||||
|
||||
# Number of times to repeat server selection
|
||||
iterations = scenario_def["iterations"]
|
||||
for _ in range(iterations):
|
||||
server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
|
||||
counts[server.description.address] += 1
|
||||
|
||||
# Verify expected_frequencies
|
||||
outcome = scenario_def["outcome"]
|
||||
tolerance = outcome["tolerance"]
|
||||
expected_frequencies = outcome["expected_frequencies"]
|
||||
for host_str, freq in expected_frequencies.items():
|
||||
address = clean_node(host_str)
|
||||
actual_freq = float(counts[address]) / iterations
|
||||
if freq == 0:
|
||||
# Should be exactly 0.
|
||||
self.assertEqual(actual_freq, 0)
|
||||
else:
|
||||
# Should be within 'tolerance'.
|
||||
self.assertAlmostEqual(actual_freq, freq, delta=tolerance)
|
||||
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
async def run_scenario(self):
|
||||
await self.run_scenario(scenario_def)
|
||||
|
||||
return run_scenario
|
||||
|
||||
|
||||
class CustomSpecTestCreator(AsyncSpecTestCreator):
|
||||
def tests(self, scenario_def):
|
||||
"""Extract the tests from a spec file.
|
||||
|
||||
Server selection in_window tests do not have a 'tests' field.
|
||||
The whole file represents a single test case.
|
||||
"""
|
||||
return [scenario_def]
|
||||
|
||||
|
||||
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
|
||||
|
||||
|
||||
class FinderTask(ConcurrentRunner):
|
||||
def __init__(self, collection, iterations):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self.collection = collection
|
||||
self.iterations = iterations
|
||||
self.passed = False
|
||||
|
||||
async def run(self):
|
||||
for _ in range(self.iterations):
|
||||
await self.collection.find_one({})
|
||||
self.passed = True
|
||||
|
||||
|
||||
class TestProse(AsyncIntegrationTest):
|
||||
async def frequencies(self, client, listener, n_finds=10):
|
||||
coll = client.test.test
|
||||
N_TASKS = 10
|
||||
tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
|
||||
for task in tasks:
|
||||
await task.start()
|
||||
for task in tasks:
|
||||
await task.join()
|
||||
for task in tasks:
|
||||
self.assertTrue(task.passed)
|
||||
|
||||
events = listener.started_events
|
||||
self.assertEqual(len(events), n_finds * N_TASKS)
|
||||
nodes = client.nodes
|
||||
self.assertEqual(len(nodes), 2)
|
||||
freqs = {address: 0.0 for address in nodes}
|
||||
for event in events:
|
||||
freqs[event.connection_id] += 1
|
||||
for address in freqs:
|
||||
freqs[address] = freqs[address] / float(len(events))
|
||||
return freqs
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
@async_client_context.require_multiple_mongoses
|
||||
async def test_load_balancing(self):
|
||||
listener = OvertCommandListener()
|
||||
cmap_listener = CMAPListener()
|
||||
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
|
||||
# varying RTTs.
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(),
|
||||
appName="loadBalancingTest",
|
||||
event_listeners=[listener, cmap_listener],
|
||||
localThresholdMS=30000,
|
||||
minPoolSize=10,
|
||||
)
|
||||
await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
|
||||
# Wait for both pools to be populated.
|
||||
await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20)
|
||||
# Delay find commands on only one mongos.
|
||||
delay_finds = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 10000},
|
||||
"data": {
|
||||
"failCommands": ["find"],
|
||||
"blockConnection": True,
|
||||
"blockTimeMS": 500,
|
||||
"appName": "loadBalancingTest",
|
||||
},
|
||||
}
|
||||
async with self.fail_point(delay_finds):
|
||||
nodes = async_client_context.client.nodes
|
||||
self.assertEqual(len(nodes), 1)
|
||||
delayed_server = next(iter(nodes))
|
||||
freqs = await self.frequencies(client, listener)
|
||||
self.assertLessEqual(freqs[delayed_server], 0.25)
|
||||
listener.reset()
|
||||
freqs = await self.frequencies(client, listener, n_finds=150)
|
||||
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
203
test/asynchronous/utils_selection_tests.py
Normal file
203
test/asynchronous/utils_selection_tests.py
Normal file
@ -0,0 +1,203 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for testing Server Selection and Max Staleness."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
from test.asynchronous import AsyncPyMongoTestCase
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.pymongo_mocks import DummyMonitor
|
||||
from test.utils import AsyncMockPool, parse_read_preference
|
||||
from test.utils_selection_tests_shared import (
|
||||
get_addresses,
|
||||
get_topology_type_name,
|
||||
make_server_description,
|
||||
)
|
||||
|
||||
from bson import json_util
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology
|
||||
from pymongo.common import HEARTBEAT_FREQUENCY
|
||||
from pymongo.errors import AutoReconnect, ConfigurationError
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def get_topology_settings_dict(**kwargs):
|
||||
settings = {
|
||||
"monitor_class": DummyMonitor,
|
||||
"heartbeat_frequency": HEARTBEAT_FREQUENCY,
|
||||
"pool_class": AsyncMockPool,
|
||||
}
|
||||
settings.update(kwargs)
|
||||
return settings
|
||||
|
||||
|
||||
async def create_topology(scenario_def, **kwargs):
|
||||
# Initialize topologies.
|
||||
if "heartbeatFrequencyMS" in scenario_def:
|
||||
frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0
|
||||
else:
|
||||
frequency = HEARTBEAT_FREQUENCY
|
||||
|
||||
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
|
||||
|
||||
topology_type = get_topology_type_name(scenario_def)
|
||||
if topology_type == "LoadBalanced":
|
||||
kwargs.setdefault("load_balanced", True)
|
||||
# Force topology description to ReplicaSet
|
||||
elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]:
|
||||
kwargs.setdefault("replica_set_name", "rs")
|
||||
settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs)
|
||||
|
||||
# "Eligible servers" is defined in the server selection spec as
|
||||
# the set of servers matching both the ReadPreference's mode
|
||||
# and tag sets.
|
||||
topology = Topology(TopologySettings(**settings))
|
||||
await topology.open()
|
||||
|
||||
# Update topologies with server descriptions.
|
||||
for server in scenario_def["topology_description"]["servers"]:
|
||||
server_description = make_server_description(server, hosts)
|
||||
await topology.on_change(server_description)
|
||||
|
||||
# Assert that descriptions match
|
||||
assert (
|
||||
scenario_def["topology_description"]["type"] == topology.description.topology_type_name
|
||||
), topology.description.topology_type_name
|
||||
|
||||
return topology
|
||||
|
||||
|
||||
def create_test(scenario_def):
|
||||
async def run_scenario(self):
|
||||
_, hosts = get_addresses(scenario_def["topology_description"]["servers"])
|
||||
# "Eligible servers" is defined in the server selection spec as
|
||||
# the set of servers matching both the ReadPreference's mode
|
||||
# and tag sets.
|
||||
top_latency = await create_topology(scenario_def)
|
||||
|
||||
# "In latency window" is defined in the server selection
|
||||
# spec as the subset of suitable_servers that falls within the
|
||||
# allowable latency window.
|
||||
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
|
||||
|
||||
# Create server selector.
|
||||
if scenario_def.get("operation") == "write":
|
||||
pref = writable_server_selector
|
||||
else:
|
||||
# Make first letter lowercase to match read_pref's modes.
|
||||
pref_def = scenario_def["read_preference"]
|
||||
if scenario_def.get("error"):
|
||||
with self.assertRaises((ConfigurationError, ValueError)):
|
||||
# Error can be raised when making Read Pref or selecting.
|
||||
pref = parse_read_preference(pref_def)
|
||||
await top_latency.select_server(pref, _Op.TEST)
|
||||
return
|
||||
|
||||
pref = parse_read_preference(pref_def)
|
||||
|
||||
# Select servers.
|
||||
if not scenario_def.get("suitable_servers"):
|
||||
with self.assertRaises(AutoReconnect):
|
||||
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)
|
||||
|
||||
return
|
||||
|
||||
if not scenario_def["in_latency_window"]:
|
||||
with self.assertRaises(AutoReconnect):
|
||||
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)
|
||||
|
||||
return
|
||||
|
||||
actual_suitable_s = await top_suitable.select_servers(
|
||||
pref, _Op.TEST, server_selection_timeout=0
|
||||
)
|
||||
actual_latency_s = await top_latency.select_servers(
|
||||
pref, _Op.TEST, server_selection_timeout=0
|
||||
)
|
||||
|
||||
expected_suitable_servers = {}
|
||||
for server in scenario_def["suitable_servers"]:
|
||||
server_description = make_server_description(server, hosts)
|
||||
expected_suitable_servers[server["address"]] = server_description
|
||||
|
||||
actual_suitable_servers = {}
|
||||
for s in actual_suitable_s:
|
||||
actual_suitable_servers[
|
||||
"%s:%d" % (s.description.address[0], s.description.address[1])
|
||||
] = s.description
|
||||
|
||||
self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers))
|
||||
for k, actual in actual_suitable_servers.items():
|
||||
expected = expected_suitable_servers[k]
|
||||
self.assertEqual(expected.address, actual.address)
|
||||
self.assertEqual(expected.server_type, actual.server_type)
|
||||
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
|
||||
self.assertEqual(expected.tags, actual.tags)
|
||||
self.assertEqual(expected.all_hosts, actual.all_hosts)
|
||||
|
||||
expected_latency_servers = {}
|
||||
for server in scenario_def["in_latency_window"]:
|
||||
server_description = make_server_description(server, hosts)
|
||||
expected_latency_servers[server["address"]] = server_description
|
||||
|
||||
actual_latency_servers = {}
|
||||
for s in actual_latency_s:
|
||||
actual_latency_servers[
|
||||
"%s:%d" % (s.description.address[0], s.description.address[1])
|
||||
] = s.description
|
||||
|
||||
self.assertEqual(len(actual_latency_servers), len(expected_latency_servers))
|
||||
for k, actual in actual_latency_servers.items():
|
||||
expected = expected_latency_servers[k]
|
||||
self.assertEqual(expected.address, actual.address)
|
||||
self.assertEqual(expected.server_type, actual.server_type)
|
||||
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
|
||||
self.assertEqual(expected.tags, actual.tags)
|
||||
self.assertEqual(expected.all_hosts, actual.all_hosts)
|
||||
|
||||
return run_scenario
|
||||
|
||||
|
||||
def create_selection_tests(test_dir):
|
||||
class TestAllScenarios(AsyncPyMongoTestCase):
|
||||
pass
|
||||
|
||||
for dirpath, _, filenames in os.walk(test_dir):
|
||||
dirname = os.path.split(dirpath)
|
||||
dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1]
|
||||
|
||||
for filename in filenames:
|
||||
if os.path.splitext(filename)[1] != ".json":
|
||||
continue
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream:
|
||||
scenario_def = json_util.loads(scenario_stream.read())
|
||||
|
||||
# Construct test from scenario.
|
||||
new_test = create_test(scenario_def)
|
||||
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(TestAllScenarios, new_test.__name__, new_test)
|
||||
|
||||
return TestAllScenarios
|
||||
@ -229,7 +229,7 @@ class AsyncSpecTestCreator:
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = await self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
|
||||
@ -959,7 +959,6 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
@client_context.require_replica_set
|
||||
@client_context.require_secondaries_count(1)
|
||||
def test_write_concern_failure_ordered(self):
|
||||
self.skipTest("Skipping until PYTHON-4865 is resolved.")
|
||||
details = None
|
||||
|
||||
# Ensure we don't raise on wnote.
|
||||
|
||||
@ -15,9 +15,11 @@
|
||||
"""Execute Transactions Spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -60,6 +62,8 @@ from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.pool import PoolState, _PoolClosedError
|
||||
from pymongo.topology_description import updated_topology_description
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
OBJECT_TYPES = {
|
||||
# Event types.
|
||||
"ConnectionCheckedIn": ConnectionCheckedInEvent,
|
||||
@ -81,7 +85,10 @@ OBJECT_TYPES = {
|
||||
|
||||
class TestCMAP(IntegrationTest):
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring")
|
||||
|
||||
# Test operations:
|
||||
|
||||
@ -258,7 +265,6 @@ class TestCMAP(IntegrationTest):
|
||||
client._topology.open()
|
||||
else:
|
||||
client._get_topology()
|
||||
self.addCleanup(client.close)
|
||||
self.pool = list(client._topology._servers.values())[0].pool
|
||||
|
||||
# Map of target names to Thread objects.
|
||||
@ -315,13 +321,11 @@ class TestCMAP(IntegrationTest):
|
||||
#
|
||||
def test_1_client_connection_pool_options(self):
|
||||
client = self.rs_or_single_client(**self.POOL_OPTIONS)
|
||||
self.addCleanup(client.close)
|
||||
pool_opts = get_pool(client).opts
|
||||
pool_opts = (get_pool(client)).opts
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
|
||||
def test_2_all_client_pools_have_same_options(self):
|
||||
client = self.rs_or_single_client(**self.POOL_OPTIONS)
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping")
|
||||
# Discover at least one secondary.
|
||||
if client_context.has_secondaries:
|
||||
@ -337,14 +341,12 @@ class TestCMAP(IntegrationTest):
|
||||
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
|
||||
uri = f"mongodb://{client_context.pair}/?{opts}"
|
||||
client = self.rs_or_single_client(uri)
|
||||
self.addCleanup(client.close)
|
||||
pool_opts = get_pool(client).opts
|
||||
pool_opts = (get_pool(client)).opts
|
||||
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
|
||||
|
||||
def test_4_subscribe_to_events(self):
|
||||
listener = CMAPListener()
|
||||
client = self.single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
self.assertEqual(listener.event_count(PoolCreatedEvent), 1)
|
||||
|
||||
# Creates a new connection.
|
||||
@ -368,7 +370,6 @@ class TestCMAP(IntegrationTest):
|
||||
def test_5_check_out_fails_connection_error(self):
|
||||
listener = CMAPListener()
|
||||
client = self.single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
pool = get_pool(client)
|
||||
|
||||
def mock_connect(*args, **kwargs):
|
||||
@ -397,7 +398,6 @@ class TestCMAP(IntegrationTest):
|
||||
client = self.single_client_noauth(
|
||||
username="notauser", password="fail", event_listeners=[listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# Attempt to create a new connection.
|
||||
with self.assertRaisesRegex(OperationFailure, "failed"):
|
||||
|
||||
@ -16,11 +16,13 @@
|
||||
"""Tests for the gridfs package."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from io import BytesIO
|
||||
from test.helpers import ConcurrentRunner
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -41,10 +43,12 @@ from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
class JustWrite(threading.Thread):
|
||||
|
||||
class JustWrite(ConcurrentRunner):
|
||||
def __init__(self, fs, n):
|
||||
threading.Thread.__init__(self)
|
||||
super().__init__()
|
||||
self.fs = fs
|
||||
self.n = n
|
||||
self.daemon = True
|
||||
@ -56,9 +60,9 @@ class JustWrite(threading.Thread):
|
||||
file.close()
|
||||
|
||||
|
||||
class JustRead(threading.Thread):
|
||||
class JustRead(ConcurrentRunner):
|
||||
def __init__(self, fs, n, results):
|
||||
threading.Thread.__init__(self)
|
||||
super().__init__()
|
||||
self.fs = fs
|
||||
self.n = n
|
||||
self.results = results
|
||||
@ -98,19 +102,21 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
def test_basic(self):
|
||||
oid = self.fs.put(b"hello world")
|
||||
self.assertEqual(b"hello world", self.fs.get(oid).read())
|
||||
self.assertEqual(b"hello world", (self.fs.get(oid)).read())
|
||||
self.assertEqual(1, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
self.fs.delete(oid)
|
||||
self.assertRaises(NoFile, self.fs.get, oid)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get(oid)
|
||||
self.assertEqual(0, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
self.assertRaises(NoFile, self.fs.get, "foo")
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get("foo")
|
||||
oid = self.fs.put(b"hello world", _id="foo")
|
||||
self.assertEqual("foo", oid)
|
||||
self.assertEqual(b"hello world", self.fs.get("foo").read())
|
||||
self.assertEqual(b"hello world", (self.fs.get("foo")).read())
|
||||
|
||||
def test_multi_chunk_delete(self):
|
||||
self.db.fs.drop()
|
||||
@ -142,7 +148,7 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
def test_empty_file(self):
|
||||
oid = self.fs.put(b"")
|
||||
self.assertEqual(b"", self.fs.get(oid).read())
|
||||
self.assertEqual(b"", (self.fs.get(oid)).read())
|
||||
self.assertEqual(1, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
@ -159,10 +165,12 @@ class TestGridfs(IntegrationTest):
|
||||
self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}})
|
||||
try:
|
||||
out = self.fs.get(files_id)
|
||||
self.assertRaises(CorruptGridFile, out.read)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
out.read()
|
||||
|
||||
out = self.fs.get(files_id)
|
||||
self.assertRaises(CorruptGridFile, out.readline)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
out.readline()
|
||||
finally:
|
||||
self.fs.delete(files_id)
|
||||
|
||||
@ -177,31 +185,33 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("files_id", 1), ("n", 1)]
|
||||
for info in chunks.index_information().values()
|
||||
for info in (chunks.index_information()).values()
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("filename", 1), ("uploadDate", 1)]
|
||||
for info in files.index_information().values()
|
||||
for info in (files.index_information()).values()
|
||||
)
|
||||
)
|
||||
|
||||
def test_alt_collection(self):
|
||||
oid = self.alt.put(b"hello world")
|
||||
self.assertEqual(b"hello world", self.alt.get(oid).read())
|
||||
self.assertEqual(b"hello world", (self.alt.get(oid)).read())
|
||||
self.assertEqual(1, self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(1, self.db.alt.chunks.count_documents({}))
|
||||
|
||||
self.alt.delete(oid)
|
||||
self.assertRaises(NoFile, self.alt.get, oid)
|
||||
with self.assertRaises(NoFile):
|
||||
self.alt.get(oid)
|
||||
self.assertEqual(0, self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.alt.chunks.count_documents({}))
|
||||
|
||||
self.assertRaises(NoFile, self.alt.get, "foo")
|
||||
with self.assertRaises(NoFile):
|
||||
self.alt.get("foo")
|
||||
oid = self.alt.put(b"hello world", _id="foo")
|
||||
self.assertEqual("foo", oid)
|
||||
self.assertEqual(b"hello world", self.alt.get("foo").read())
|
||||
self.assertEqual(b"hello world", (self.alt.get("foo")).read())
|
||||
|
||||
self.alt.put(b"", filename="mike")
|
||||
self.alt.put(b"foo", filename="test")
|
||||
@ -212,23 +222,23 @@ class TestGridfs(IntegrationTest):
|
||||
def test_threaded_reads(self):
|
||||
self.fs.put(b"hello", _id="test")
|
||||
|
||||
threads = []
|
||||
tasks = []
|
||||
results: list = []
|
||||
for i in range(10):
|
||||
threads.append(JustRead(self.fs, 10, results))
|
||||
threads[i].start()
|
||||
tasks.append(JustRead(self.fs, 10, results))
|
||||
tasks[i].start()
|
||||
|
||||
joinall(threads)
|
||||
joinall(tasks)
|
||||
|
||||
self.assertEqual(100 * [b"hello"], results)
|
||||
|
||||
def test_threaded_writes(self):
|
||||
threads = []
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
threads.append(JustWrite(self.fs, 10))
|
||||
threads[i].start()
|
||||
tasks.append(JustWrite(self.fs, 10))
|
||||
tasks[i].start()
|
||||
|
||||
joinall(threads)
|
||||
joinall(tasks)
|
||||
|
||||
f = self.fs.get_last_version("test")
|
||||
self.assertEqual(f.read(), b"hello")
|
||||
@ -246,34 +256,37 @@ class TestGridfs(IntegrationTest):
|
||||
two = two._id
|
||||
three = self.fs.put(b"baz", filename="test")
|
||||
|
||||
self.assertEqual(b"baz", self.fs.get_last_version("test").read())
|
||||
self.assertEqual(b"baz", (self.fs.get_last_version("test")).read())
|
||||
self.fs.delete(three)
|
||||
self.assertEqual(b"bar", self.fs.get_last_version("test").read())
|
||||
self.assertEqual(b"bar", (self.fs.get_last_version("test")).read())
|
||||
self.fs.delete(two)
|
||||
self.assertEqual(b"foo", self.fs.get_last_version("test").read())
|
||||
self.assertEqual(b"foo", (self.fs.get_last_version("test")).read())
|
||||
self.fs.delete(one)
|
||||
self.assertRaises(NoFile, self.fs.get_last_version, "test")
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_last_version("test")
|
||||
|
||||
def test_get_last_version_with_metadata(self):
|
||||
one = self.fs.put(b"foo", filename="test", author="author")
|
||||
time.sleep(0.01)
|
||||
two = self.fs.put(b"bar", filename="test", author="author")
|
||||
|
||||
self.assertEqual(b"bar", self.fs.get_last_version(author="author").read())
|
||||
self.assertEqual(b"bar", (self.fs.get_last_version(author="author")).read())
|
||||
self.fs.delete(two)
|
||||
self.assertEqual(b"foo", self.fs.get_last_version(author="author").read())
|
||||
self.assertEqual(b"foo", (self.fs.get_last_version(author="author")).read())
|
||||
self.fs.delete(one)
|
||||
|
||||
one = self.fs.put(b"foo", filename="test", author="author1")
|
||||
time.sleep(0.01)
|
||||
two = self.fs.put(b"bar", filename="test", author="author2")
|
||||
|
||||
self.assertEqual(b"foo", self.fs.get_last_version(author="author1").read())
|
||||
self.assertEqual(b"bar", self.fs.get_last_version(author="author2").read())
|
||||
self.assertEqual(b"bar", self.fs.get_last_version(filename="test").read())
|
||||
self.assertEqual(b"foo", (self.fs.get_last_version(author="author1")).read())
|
||||
self.assertEqual(b"bar", (self.fs.get_last_version(author="author2")).read())
|
||||
self.assertEqual(b"bar", (self.fs.get_last_version(filename="test")).read())
|
||||
|
||||
self.assertRaises(NoFile, self.fs.get_last_version, author="author3")
|
||||
self.assertRaises(NoFile, self.fs.get_last_version, filename="nottest", author="author1")
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_last_version(author="author3")
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_last_version(filename="nottest", author="author1")
|
||||
|
||||
self.fs.delete(one)
|
||||
self.fs.delete(two)
|
||||
@ -286,16 +299,18 @@ class TestGridfs(IntegrationTest):
|
||||
self.fs.put(b"baz", filename="test")
|
||||
time.sleep(0.01)
|
||||
|
||||
self.assertEqual(b"foo", self.fs.get_version("test", 0).read())
|
||||
self.assertEqual(b"bar", self.fs.get_version("test", 1).read())
|
||||
self.assertEqual(b"baz", self.fs.get_version("test", 2).read())
|
||||
self.assertEqual(b"foo", (self.fs.get_version("test", 0)).read())
|
||||
self.assertEqual(b"bar", (self.fs.get_version("test", 1)).read())
|
||||
self.assertEqual(b"baz", (self.fs.get_version("test", 2)).read())
|
||||
|
||||
self.assertEqual(b"baz", self.fs.get_version("test", -1).read())
|
||||
self.assertEqual(b"bar", self.fs.get_version("test", -2).read())
|
||||
self.assertEqual(b"foo", self.fs.get_version("test", -3).read())
|
||||
self.assertEqual(b"baz", (self.fs.get_version("test", -1)).read())
|
||||
self.assertEqual(b"bar", (self.fs.get_version("test", -2)).read())
|
||||
self.assertEqual(b"foo", (self.fs.get_version("test", -3)).read())
|
||||
|
||||
self.assertRaises(NoFile, self.fs.get_version, "test", 3)
|
||||
self.assertRaises(NoFile, self.fs.get_version, "test", -4)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_version("test", 3)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_version("test", -4)
|
||||
|
||||
def test_get_version_with_metadata(self):
|
||||
one = self.fs.put(b"foo", filename="test", author="author1")
|
||||
@ -305,25 +320,32 @@ class TestGridfs(IntegrationTest):
|
||||
three = self.fs.put(b"baz", filename="test", author="author2")
|
||||
|
||||
self.assertEqual(
|
||||
b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read()
|
||||
b"foo",
|
||||
(self.fs.get_version(filename="test", author="author1", version=-2)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read()
|
||||
b"bar",
|
||||
(self.fs.get_version(filename="test", author="author1", version=-1)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"foo", self.fs.get_version(filename="test", author="author1", version=0).read()
|
||||
b"foo",
|
||||
(self.fs.get_version(filename="test", author="author1", version=0)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"bar", self.fs.get_version(filename="test", author="author1", version=1).read()
|
||||
b"bar",
|
||||
(self.fs.get_version(filename="test", author="author1", version=1)).read(),
|
||||
)
|
||||
self.assertEqual(
|
||||
b"baz", self.fs.get_version(filename="test", author="author2", version=0).read()
|
||||
b"baz",
|
||||
(self.fs.get_version(filename="test", author="author2", version=0)).read(),
|
||||
)
|
||||
self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read())
|
||||
self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read())
|
||||
self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=-1)).read())
|
||||
self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=2)).read())
|
||||
|
||||
self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author3")
|
||||
self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_version(filename="test", author="author3")
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.get_version(filename="test", author="author1", version=2)
|
||||
|
||||
self.fs.delete(one)
|
||||
self.fs.delete(two)
|
||||
@ -332,11 +354,12 @@ class TestGridfs(IntegrationTest):
|
||||
def test_put_filelike(self):
|
||||
oid = self.fs.put(BytesIO(b"hello world"), chunk_size=1)
|
||||
self.assertEqual(11, self.db.fs.chunks.count_documents({}))
|
||||
self.assertEqual(b"hello world", self.fs.get(oid).read())
|
||||
self.assertEqual(b"hello world", (self.fs.get(oid)).read())
|
||||
|
||||
def test_file_exists(self):
|
||||
oid = self.fs.put(b"hello")
|
||||
self.assertRaises(FileExists, self.fs.put, b"world", _id=oid)
|
||||
with self.assertRaises(FileExists):
|
||||
self.fs.put(b"world", _id=oid)
|
||||
|
||||
one = self.fs.new_file(_id=123)
|
||||
one.write(b"some content")
|
||||
@ -345,15 +368,17 @@ class TestGridfs(IntegrationTest):
|
||||
# Attempt to upload a file with more chunks to the same _id.
|
||||
with patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE):
|
||||
two = self.fs.new_file(_id=123)
|
||||
self.assertRaises(FileExists, two.write, b"x" * DEFAULT_CHUNK_SIZE * 3)
|
||||
with self.assertRaises(FileExists):
|
||||
two.write(b"x" * DEFAULT_CHUNK_SIZE * 3)
|
||||
# Original file is still readable (no extra chunks were uploaded).
|
||||
self.assertEqual(self.fs.get(123).read(), b"some content")
|
||||
self.assertEqual((self.fs.get(123)).read(), b"some content")
|
||||
|
||||
two = self.fs.new_file(_id=123)
|
||||
two.write(b"some content")
|
||||
self.assertRaises(FileExists, two.close)
|
||||
with self.assertRaises(FileExists):
|
||||
two.close()
|
||||
# Original file is still readable.
|
||||
self.assertEqual(self.fs.get(123).read(), b"some content")
|
||||
self.assertEqual((self.fs.get(123)).read(), b"some content")
|
||||
|
||||
def test_exists(self):
|
||||
oid = self.fs.put(b"hello")
|
||||
@ -381,15 +406,16 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertFalse(self.fs.exists({"foo": {"$gt": 12}}))
|
||||
|
||||
def test_put_unicode(self):
|
||||
self.assertRaises(TypeError, self.fs.put, "hello")
|
||||
with self.assertRaises(TypeError):
|
||||
self.fs.put("hello")
|
||||
|
||||
oid = self.fs.put("hello", encoding="utf-8")
|
||||
self.assertEqual(b"hello", self.fs.get(oid).read())
|
||||
self.assertEqual("utf-8", self.fs.get(oid).encoding)
|
||||
self.assertEqual(b"hello", (self.fs.get(oid)).read())
|
||||
self.assertEqual("utf-8", (self.fs.get(oid)).encoding)
|
||||
|
||||
oid = self.fs.put("aé", encoding="iso-8859-1")
|
||||
self.assertEqual("aé".encode("iso-8859-1"), self.fs.get(oid).read())
|
||||
self.assertEqual("iso-8859-1", self.fs.get(oid).encoding)
|
||||
self.assertEqual("aé".encode("iso-8859-1"), (self.fs.get(oid)).read())
|
||||
self.assertEqual("iso-8859-1", (self.fs.get(oid)).encoding)
|
||||
|
||||
def test_missing_length_iter(self):
|
||||
# Test fix that guards against PHP-237
|
||||
@ -411,11 +437,13 @@ class TestGridfs(IntegrationTest):
|
||||
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
db = client.db
|
||||
gfs = gridfs.GridFS(db)
|
||||
self.assertRaises(ServerSelectionTimeoutError, gfs.list)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
gfs.list()
|
||||
|
||||
fs = gridfs.GridFS(db)
|
||||
f = fs.new_file()
|
||||
self.assertRaises(ServerSelectionTimeoutError, f.close)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
f.close()
|
||||
|
||||
def test_gridfs_find(self):
|
||||
self.fs.put(b"test2", filename="two")
|
||||
@ -429,14 +457,15 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertEqual(3, files.count_documents({"filename": "two"}))
|
||||
self.assertEqual(4, files.count_documents({}))
|
||||
cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2)
|
||||
gout = next(cursor)
|
||||
gout = cursor.next()
|
||||
self.assertEqual(b"test1", gout.read())
|
||||
cursor.rewind()
|
||||
gout = next(cursor)
|
||||
gout = cursor.next()
|
||||
self.assertEqual(b"test1", gout.read())
|
||||
gout = next(cursor)
|
||||
gout = cursor.next()
|
||||
self.assertEqual(b"test2+", gout.read())
|
||||
self.assertRaises(StopIteration, cursor.__next__)
|
||||
with self.assertRaises(StopIteration):
|
||||
cursor.__next__()
|
||||
cursor.rewind()
|
||||
items = cursor.to_list()
|
||||
self.assertEqual(len(items), 2)
|
||||
@ -484,12 +513,12 @@ class TestGridfs(IntegrationTest):
|
||||
self.fs.put(data, filename="f")
|
||||
self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
|
||||
|
||||
self.assertEqual(data, self.fs.get_version("f").read())
|
||||
self.assertEqual(data, (self.fs.get_version("f")).read())
|
||||
|
||||
def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test)
|
||||
gridfs.GridFS((self.rs_or_single_client(w=0)).pymongo_test)
|
||||
|
||||
def test_md5(self):
|
||||
gin = self.fs.new_file()
|
||||
@ -524,7 +553,7 @@ class TestGridfsReplicaSet(IntegrationTest):
|
||||
self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY)
|
||||
|
||||
oid = fs.put(b"foo")
|
||||
content = fs.get(oid).read()
|
||||
content = (fs.get(oid)).read()
|
||||
self.assertEqual(b"foo", content)
|
||||
|
||||
def test_gridfs_secondary(self):
|
||||
@ -538,7 +567,8 @@ class TestGridfsReplicaSet(IntegrationTest):
|
||||
fs = gridfs.GridFS(secondary_connection.gfsreplica, "gfssecondarytest")
|
||||
|
||||
# This won't detect secondary, raises error
|
||||
self.assertRaises(NotPrimaryError, fs.put, b"foo")
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
fs.put(b"foo")
|
||||
|
||||
def test_gridfs_secondary_lazy(self):
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
@ -552,8 +582,10 @@ class TestGridfsReplicaSet(IntegrationTest):
|
||||
fs = gridfs.GridFS(client.gfsreplica, "gfssecondarylazytest")
|
||||
|
||||
# Connects, doesn't create index.
|
||||
self.assertRaises(NoFile, fs.get_last_version)
|
||||
self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8")
|
||||
with self.assertRaises(NoFile):
|
||||
fs.get_last_version()
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
fs.put("data", encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -16,12 +16,14 @@
|
||||
"""Tests for the gridfs package."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import itertools
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from io import BytesIO
|
||||
from test.helpers import ConcurrentRunner
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -44,10 +46,12 @@ from pymongo.errors import (
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
class JustWrite(threading.Thread):
|
||||
|
||||
class JustWrite(ConcurrentRunner):
|
||||
def __init__(self, gfs, num):
|
||||
threading.Thread.__init__(self)
|
||||
super().__init__()
|
||||
self.gfs = gfs
|
||||
self.num = num
|
||||
self.daemon = True
|
||||
@ -59,9 +63,9 @@ class JustWrite(threading.Thread):
|
||||
file.close()
|
||||
|
||||
|
||||
class JustRead(threading.Thread):
|
||||
class JustRead(ConcurrentRunner):
|
||||
def __init__(self, gfs, num, results):
|
||||
threading.Thread.__init__(self)
|
||||
super().__init__()
|
||||
self.gfs = gfs
|
||||
self.num = num
|
||||
self.results = results
|
||||
@ -89,12 +93,13 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
def test_basic(self):
|
||||
oid = self.fs.upload_from_stream("test_filename", b"hello world")
|
||||
self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read())
|
||||
self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read())
|
||||
self.assertEqual(1, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
self.fs.delete(oid)
|
||||
self.assertRaises(NoFile, self.fs.open_download_stream, oid)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.open_download_stream(oid)
|
||||
self.assertEqual(0, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
@ -111,7 +116,7 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
def test_empty_file(self):
|
||||
oid = self.fs.upload_from_stream("test_filename", b"")
|
||||
self.assertEqual(b"", self.fs.open_download_stream(oid).read())
|
||||
self.assertEqual(b"", (self.fs.open_download_stream(oid)).read())
|
||||
self.assertEqual(1, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
@ -128,10 +133,12 @@ class TestGridfs(IntegrationTest):
|
||||
self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}})
|
||||
try:
|
||||
out = self.fs.open_download_stream(files_id)
|
||||
self.assertRaises(CorruptGridFile, out.read)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
out.read()
|
||||
|
||||
out = self.fs.open_download_stream(files_id)
|
||||
self.assertRaises(CorruptGridFile, out.readline)
|
||||
with self.assertRaises(CorruptGridFile):
|
||||
out.readline()
|
||||
finally:
|
||||
self.fs.delete(files_id)
|
||||
|
||||
@ -146,13 +153,13 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("files_id", 1), ("n", 1)]
|
||||
for info in chunks.index_information().values()
|
||||
for info in (chunks.index_information()).values()
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("filename", 1), ("uploadDate", 1)]
|
||||
for info in files.index_information().values()
|
||||
for info in (files.index_information()).values()
|
||||
)
|
||||
)
|
||||
|
||||
@ -174,25 +181,27 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertTrue(
|
||||
any(
|
||||
info.get("key") == [("filename", 1), ("uploadDate", 1)]
|
||||
for info in files.index_information().values()
|
||||
for info in (files.index_information()).values()
|
||||
)
|
||||
)
|
||||
files.drop()
|
||||
|
||||
def test_alt_collection(self):
|
||||
oid = self.alt.upload_from_stream("test_filename", b"hello world")
|
||||
self.assertEqual(b"hello world", self.alt.open_download_stream(oid).read())
|
||||
self.assertEqual(b"hello world", (self.alt.open_download_stream(oid)).read())
|
||||
self.assertEqual(1, self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(1, self.db.alt.chunks.count_documents({}))
|
||||
|
||||
self.alt.delete(oid)
|
||||
self.assertRaises(NoFile, self.alt.open_download_stream, oid)
|
||||
with self.assertRaises(NoFile):
|
||||
self.alt.open_download_stream(oid)
|
||||
self.assertEqual(0, self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.alt.chunks.count_documents({}))
|
||||
|
||||
self.assertRaises(NoFile, self.alt.open_download_stream, "foo")
|
||||
with self.assertRaises(NoFile):
|
||||
self.alt.open_download_stream("foo")
|
||||
self.alt.upload_from_stream("foo", b"hello world")
|
||||
self.assertEqual(b"hello world", self.alt.open_download_stream_by_name("foo").read())
|
||||
self.assertEqual(b"hello world", (self.alt.open_download_stream_by_name("foo")).read())
|
||||
|
||||
self.alt.upload_from_stream("mike", b"")
|
||||
self.alt.upload_from_stream("test", b"foo")
|
||||
@ -200,7 +209,7 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
self.assertEqual(
|
||||
{"mike", "test", "hello world", "foo"},
|
||||
{k["filename"] for k in list(self.db.alt.files.find())},
|
||||
{k["filename"] for k in self.db.alt.files.find().to_list()},
|
||||
)
|
||||
|
||||
def test_threaded_reads(self):
|
||||
@ -240,13 +249,14 @@ class TestGridfs(IntegrationTest):
|
||||
two = two._id
|
||||
three = self.fs.upload_from_stream("test", b"baz")
|
||||
|
||||
self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test").read())
|
||||
self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test")).read())
|
||||
self.fs.delete(three)
|
||||
self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test").read())
|
||||
self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test")).read())
|
||||
self.fs.delete(two)
|
||||
self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test").read())
|
||||
self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test")).read())
|
||||
self.fs.delete(one)
|
||||
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test")
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.open_download_stream_by_name("test")
|
||||
|
||||
def test_get_version(self):
|
||||
self.fs.upload_from_stream("test", b"foo")
|
||||
@ -256,28 +266,30 @@ class TestGridfs(IntegrationTest):
|
||||
self.fs.upload_from_stream("test", b"baz")
|
||||
time.sleep(0.01)
|
||||
|
||||
self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=0).read())
|
||||
self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=1).read())
|
||||
self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=2).read())
|
||||
self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=0)).read())
|
||||
self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=1)).read())
|
||||
self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=2)).read())
|
||||
|
||||
self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=-1).read())
|
||||
self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=-2).read())
|
||||
self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=-3).read())
|
||||
self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=-1)).read())
|
||||
self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=-2)).read())
|
||||
self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=-3)).read())
|
||||
|
||||
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=3)
|
||||
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=-4)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.open_download_stream_by_name("test", revision=3)
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.open_download_stream_by_name("test", revision=-4)
|
||||
|
||||
def test_upload_from_stream(self):
|
||||
oid = self.fs.upload_from_stream("test_file", BytesIO(b"hello world"), chunk_size_bytes=1)
|
||||
self.assertEqual(11, self.db.fs.chunks.count_documents({}))
|
||||
self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read())
|
||||
self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read())
|
||||
|
||||
def test_upload_from_stream_with_id(self):
|
||||
oid = ObjectId()
|
||||
self.fs.upload_from_stream_with_id(
|
||||
oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1
|
||||
)
|
||||
self.assertEqual(b"custom id", self.fs.open_download_stream(oid).read())
|
||||
self.assertEqual(b"custom id", (self.fs.open_download_stream(oid)).read())
|
||||
|
||||
@patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3)
|
||||
@client_context.require_failCommand_fail_point
|
||||
@ -316,14 +328,14 @@ class TestGridfs(IntegrationTest):
|
||||
gin = self.fs.open_upload_stream("from_stream")
|
||||
gin.write(b"from stream")
|
||||
gin.close()
|
||||
self.assertEqual(b"from stream", self.fs.open_download_stream(gin._id).read())
|
||||
self.assertEqual(b"from stream", (self.fs.open_download_stream(gin._id)).read())
|
||||
|
||||
def test_open_upload_stream_with_id(self):
|
||||
oid = ObjectId()
|
||||
gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id")
|
||||
gin.write(b"from stream with custom id")
|
||||
gin.close()
|
||||
self.assertEqual(b"from stream with custom id", self.fs.open_download_stream(oid).read())
|
||||
self.assertEqual(b"from stream with custom id", (self.fs.open_download_stream(oid)).read())
|
||||
|
||||
def test_missing_length_iter(self):
|
||||
# Test fix that guards against PHP-237
|
||||
@ -345,12 +357,12 @@ class TestGridfs(IntegrationTest):
|
||||
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0)
|
||||
cdb = client.db
|
||||
gfs = gridfs.GridFSBucket(cdb)
|
||||
self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
gfs.delete(0)
|
||||
|
||||
gfs = gridfs.GridFSBucket(cdb)
|
||||
self.assertRaises(
|
||||
ServerSelectionTimeoutError, gfs.upload_from_stream, "test", b""
|
||||
) # Still no connection.
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
gfs.upload_from_stream("test", b"") # Still no connection.
|
||||
|
||||
def test_gridfs_find(self):
|
||||
self.fs.upload_from_stream("two", b"test2")
|
||||
@ -366,14 +378,15 @@ class TestGridfs(IntegrationTest):
|
||||
cursor = self.fs.find(
|
||||
{}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2
|
||||
)
|
||||
gout = next(cursor)
|
||||
gout = cursor.next()
|
||||
self.assertEqual(b"test1", gout.read())
|
||||
cursor.rewind()
|
||||
gout = next(cursor)
|
||||
gout = cursor.next()
|
||||
self.assertEqual(b"test1", gout.read())
|
||||
gout = next(cursor)
|
||||
gout = cursor.next()
|
||||
self.assertEqual(b"test2+", gout.read())
|
||||
self.assertRaises(StopIteration, cursor.__next__)
|
||||
with self.assertRaises(StopIteration):
|
||||
cursor.next()
|
||||
cursor.close()
|
||||
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
|
||||
|
||||
@ -383,20 +396,21 @@ class TestGridfs(IntegrationTest):
|
||||
self.fs.upload_from_stream("f", data)
|
||||
self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
|
||||
|
||||
self.assertEqual(data, self.fs.open_download_stream_by_name("f").read())
|
||||
self.assertEqual(data, (self.fs.open_download_stream_by_name("f")).read())
|
||||
|
||||
def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test)
|
||||
gridfs.GridFSBucket((self.rs_or_single_client(w=0)).pymongo_test)
|
||||
|
||||
def test_rename(self):
|
||||
_id = self.fs.upload_from_stream("first_name", b"testing")
|
||||
self.assertEqual(b"testing", self.fs.open_download_stream_by_name("first_name").read())
|
||||
self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("first_name")).read())
|
||||
|
||||
self.fs.rename(_id, "second_name")
|
||||
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "first_name")
|
||||
self.assertEqual(b"testing", self.fs.open_download_stream_by_name("second_name").read())
|
||||
with self.assertRaises(NoFile):
|
||||
self.fs.open_download_stream_by_name("first_name")
|
||||
self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("second_name")).read())
|
||||
|
||||
@patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", 5)
|
||||
def test_abort(self):
|
||||
@ -407,7 +421,8 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertEqual(3, self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
gin.abort()
|
||||
self.assertTrue(gin.closed)
|
||||
self.assertRaises(ValueError, gin.write, b"test4")
|
||||
with self.assertRaises(ValueError):
|
||||
gin.write(b"test4")
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({"files_id": gin._id}))
|
||||
|
||||
def test_download_to_stream(self):
|
||||
@ -490,7 +505,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
|
||||
gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest")
|
||||
oid = gfs.upload_from_stream("test_filename", b"foo")
|
||||
content = gfs.open_download_stream(oid).read()
|
||||
content = (gfs.open_download_stream(oid)).read()
|
||||
self.assertEqual(b"foo", content)
|
||||
|
||||
def test_gridfs_secondary(self):
|
||||
@ -504,7 +519,8 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
gfs = gridfs.GridFSBucket(secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest")
|
||||
|
||||
# This won't detect secondary, raises error
|
||||
self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"foo")
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
gfs.upload_from_stream("test_filename", b"foo")
|
||||
|
||||
def test_gridfs_secondary_lazy(self):
|
||||
# Should detect it's connected to secondary and not attempt to
|
||||
@ -518,8 +534,10 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
gfs = gridfs.GridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest")
|
||||
|
||||
# Connects, doesn't create index.
|
||||
self.assertRaises(NoFile, gfs.open_download_stream_by_name, "test_filename")
|
||||
self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data")
|
||||
with self.assertRaises(NoFile):
|
||||
gfs.open_download_stream_by_name("test_filename")
|
||||
with self.assertRaises(NotPrimaryError):
|
||||
gfs.upload_from_stream("test_filename", b"data")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pymongo import MongoClient, ReadPreference
|
||||
from pymongo.errors import ServerSelectionTimeoutError
|
||||
@ -43,11 +44,17 @@ from test.utils_selection_tests import (
|
||||
make_server_description,
|
||||
)
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.path.join("server_selection", "server_selection"),
|
||||
)
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent, "server_selection", "server_selection"
|
||||
)
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent, "server_selection", "server_selection"
|
||||
)
|
||||
|
||||
|
||||
class SelectionStoreSelector:
|
||||
@ -61,7 +68,7 @@ class SelectionStoreSelector:
|
||||
return selection
|
||||
|
||||
|
||||
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
|
||||
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
@ -79,13 +86,12 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
client = self.rs_or_single_client(
|
||||
server_selector=custom_selector, event_listeners=[listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll
|
||||
self.addCleanup(client.drop_database, "testdb")
|
||||
|
||||
# Wait the node list to be fully populated.
|
||||
def all_hosts_started():
|
||||
return len(client.admin.command(HelloCompat.LEGACY_CMD)["hosts"]) == len(
|
||||
return len((client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len(
|
||||
client._topology._description.readable_servers
|
||||
)
|
||||
|
||||
@ -121,7 +127,6 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
# Client setup.
|
||||
mongo_client = self.rs_or_single_client(server_selector=selector)
|
||||
test_collection = mongo_client.testdb.test_collection
|
||||
self.addCleanup(mongo_client.close)
|
||||
self.addCleanup(mongo_client.drop_database, "testdb")
|
||||
|
||||
# Do N operations and test selector is called at least N times.
|
||||
|
||||
@ -15,9 +15,12 @@
|
||||
"""Test the topology module's Server Selection Spec implementation."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.helpers import ConcurrentRunner
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
@ -32,10 +35,14 @@ from pymongo.monitoring import ConnectionReadyEvent
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
_IS_SYNC = True
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window")
|
||||
)
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
|
||||
)
|
||||
|
||||
|
||||
class TestAllScenarios(unittest.TestCase):
|
||||
@ -92,7 +99,7 @@ class CustomSpecTestCreator(SpecTestCreator):
|
||||
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
|
||||
|
||||
|
||||
class FinderThread(threading.Thread):
|
||||
class FinderTask(ConcurrentRunner):
|
||||
def __init__(self, collection, iterations):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
@ -109,17 +116,17 @@ class FinderThread(threading.Thread):
|
||||
class TestProse(IntegrationTest):
|
||||
def frequencies(self, client, listener, n_finds=10):
|
||||
coll = client.test.test
|
||||
N_THREADS = 10
|
||||
threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
for thread in threads:
|
||||
self.assertTrue(thread.passed)
|
||||
N_TASKS = 10
|
||||
tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
|
||||
for task in tasks:
|
||||
task.start()
|
||||
for task in tasks:
|
||||
task.join()
|
||||
for task in tasks:
|
||||
self.assertTrue(task.passed)
|
||||
|
||||
events = listener.started_events
|
||||
self.assertEqual(len(events), n_finds * N_THREADS)
|
||||
self.assertEqual(len(events), n_finds * N_TASKS)
|
||||
nodes = client.nodes
|
||||
self.assertEqual(len(nodes), 2)
|
||||
freqs = {address: 0.0 for address in nodes}
|
||||
|
||||
@ -666,6 +666,11 @@ def joinall(threads):
|
||||
assert not t.is_alive(), "Thread %s hung" % t
|
||||
|
||||
|
||||
async def async_joinall(tasks):
|
||||
"""Join threads with a 5-minute timeout, assert joins succeeded"""
|
||||
await asyncio.wait([t.task for t in tasks if t is not None], timeout=300)
|
||||
|
||||
|
||||
def wait_until(predicate, success_description, timeout=10):
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
@ -827,7 +832,7 @@ async def async_get_pools(client):
|
||||
"""Get all pools."""
|
||||
return [
|
||||
server.pool
|
||||
async for server in await (await client._get_topology()).select_servers(
|
||||
for server in await (await client._get_topology()).select_servers(
|
||||
any_server_selector, _Op.TEST
|
||||
)
|
||||
]
|
||||
|
||||
@ -18,96 +18,28 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
from test import PyMongoTestCase
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.pymongo_mocks import DummyMonitor
|
||||
from test.utils import MockPool, parse_read_preference
|
||||
from test.utils_selection_tests_shared import (
|
||||
get_addresses,
|
||||
get_topology_type_name,
|
||||
make_server_description,
|
||||
)
|
||||
|
||||
from bson import json_util
|
||||
from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node
|
||||
from pymongo.common import HEARTBEAT_FREQUENCY
|
||||
from pymongo.errors import AutoReconnect, ConfigurationError
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.synchronous.settings import TopologySettings
|
||||
from pymongo.synchronous.topology import Topology
|
||||
|
||||
|
||||
def get_addresses(server_list):
|
||||
seeds = []
|
||||
hosts = []
|
||||
for server in server_list:
|
||||
seeds.append(clean_node(server["address"]))
|
||||
hosts.append(server["address"])
|
||||
return seeds, hosts
|
||||
|
||||
|
||||
def make_last_write_date(server):
|
||||
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
|
||||
millis = server.get("lastWrite", {}).get("lastWriteDate")
|
||||
if millis:
|
||||
diff = ((millis % 1000) + 1000) % 1000
|
||||
seconds = (millis - diff) / 1000
|
||||
micros = diff * 1000
|
||||
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
|
||||
else:
|
||||
# "Unknown" server.
|
||||
return epoch
|
||||
|
||||
|
||||
def make_server_description(server, hosts):
|
||||
"""Make a ServerDescription from server info in a JSON test."""
|
||||
server_type = server["type"]
|
||||
if server_type in ("Unknown", "PossiblePrimary"):
|
||||
return ServerDescription(clean_node(server["address"]), Hello({}))
|
||||
|
||||
hello_response = {"ok": True, "hosts": hosts}
|
||||
if server_type not in ("Standalone", "Mongos", "RSGhost"):
|
||||
hello_response["setName"] = "rs"
|
||||
|
||||
if server_type == "RSPrimary":
|
||||
hello_response[HelloCompat.LEGACY_CMD] = True
|
||||
elif server_type == "RSSecondary":
|
||||
hello_response["secondary"] = True
|
||||
elif server_type == "Mongos":
|
||||
hello_response["msg"] = "isdbgrid"
|
||||
elif server_type == "RSGhost":
|
||||
hello_response["isreplicaset"] = True
|
||||
elif server_type == "RSArbiter":
|
||||
hello_response["arbiterOnly"] = True
|
||||
|
||||
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
|
||||
|
||||
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
|
||||
if field in server:
|
||||
hello_response[field] = server[field]
|
||||
|
||||
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
|
||||
|
||||
# Sets _last_update_time to now.
|
||||
sd = ServerDescription(
|
||||
clean_node(server["address"]),
|
||||
Hello(hello_response),
|
||||
round_trip_time=server["avg_rtt_ms"] / 1000.0,
|
||||
)
|
||||
|
||||
if "lastUpdateTime" in server:
|
||||
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
|
||||
|
||||
return sd
|
||||
|
||||
|
||||
def get_topology_type_name(scenario_def):
|
||||
td = scenario_def["topology_description"]
|
||||
name = td["type"]
|
||||
if name == "Unknown":
|
||||
# PyMongo never starts a topology in type Unknown.
|
||||
return "Sharded" if len(td["servers"]) > 1 else "Single"
|
||||
else:
|
||||
return name
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def get_topology_settings_dict(**kwargs):
|
||||
@ -244,7 +176,7 @@ def create_test(scenario_def):
|
||||
|
||||
|
||||
def create_selection_tests(test_dir):
|
||||
class TestAllScenarios(unittest.TestCase):
|
||||
class TestAllScenarios(PyMongoTestCase):
|
||||
pass
|
||||
|
||||
for dirpath, _, filenames in os.walk(test_dir):
|
||||
|
||||
100
test/utils_selection_tests_shared.py
Normal file
100
test/utils_selection_tests_shared.py
Normal file
@ -0,0 +1,100 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for testing Server Selection and Max Staleness."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.server_description import ServerDescription
|
||||
|
||||
|
||||
def get_addresses(server_list):
|
||||
seeds = []
|
||||
hosts = []
|
||||
for server in server_list:
|
||||
seeds.append(clean_node(server["address"]))
|
||||
hosts.append(server["address"])
|
||||
return seeds, hosts
|
||||
|
||||
|
||||
def make_last_write_date(server):
|
||||
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
|
||||
millis = server.get("lastWrite", {}).get("lastWriteDate")
|
||||
if millis:
|
||||
diff = ((millis % 1000) + 1000) % 1000
|
||||
seconds = (millis - diff) / 1000
|
||||
micros = diff * 1000
|
||||
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
|
||||
else:
|
||||
# "Unknown" server.
|
||||
return epoch
|
||||
|
||||
|
||||
def make_server_description(server, hosts):
|
||||
"""Make a ServerDescription from server info in a JSON test."""
|
||||
server_type = server["type"]
|
||||
if server_type in ("Unknown", "PossiblePrimary"):
|
||||
return ServerDescription(clean_node(server["address"]), Hello({}))
|
||||
|
||||
hello_response = {"ok": True, "hosts": hosts}
|
||||
if server_type not in ("Standalone", "Mongos", "RSGhost"):
|
||||
hello_response["setName"] = "rs"
|
||||
|
||||
if server_type == "RSPrimary":
|
||||
hello_response[HelloCompat.LEGACY_CMD] = True
|
||||
elif server_type == "RSSecondary":
|
||||
hello_response["secondary"] = True
|
||||
elif server_type == "Mongos":
|
||||
hello_response["msg"] = "isdbgrid"
|
||||
elif server_type == "RSGhost":
|
||||
hello_response["isreplicaset"] = True
|
||||
elif server_type == "RSArbiter":
|
||||
hello_response["arbiterOnly"] = True
|
||||
|
||||
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
|
||||
|
||||
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
|
||||
if field in server:
|
||||
hello_response[field] = server[field]
|
||||
|
||||
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
|
||||
|
||||
# Sets _last_update_time to now.
|
||||
sd = ServerDescription(
|
||||
clean_node(server["address"]),
|
||||
Hello(hello_response),
|
||||
round_trip_time=server["avg_rtt_ms"] / 1000.0,
|
||||
)
|
||||
|
||||
if "lastUpdateTime" in server:
|
||||
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
|
||||
|
||||
return sd
|
||||
|
||||
|
||||
def get_topology_type_name(scenario_def):
|
||||
td = scenario_def["topology_description"]
|
||||
name = td["type"]
|
||||
if name == "Unknown":
|
||||
# PyMongo never starts a topology in type Unknown.
|
||||
return "Sharded" if len(td["servers"]) > 1 else "Single"
|
||||
else:
|
||||
return name
|
||||
@ -122,7 +122,9 @@ replacements = {
|
||||
"SpecRunnerTask": "SpecRunnerThread",
|
||||
"AsyncMockConnection": "MockConnection",
|
||||
"AsyncMockPool": "MockPool",
|
||||
"StopAsyncIteration": "StopIteration",
|
||||
"create_async_event": "create_event",
|
||||
"async_joinall": "joinall",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
@ -169,7 +171,7 @@ gridfs_files = [
|
||||
|
||||
def async_only_test(f: str) -> bool:
|
||||
"""Return True for async tests that should not be converted to sync."""
|
||||
return f in ["test_locks.py", "test_concurrency.py"]
|
||||
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]
|
||||
|
||||
|
||||
test_files = [
|
||||
@ -202,6 +204,7 @@ converted_tests = [
|
||||
"test_comment.py",
|
||||
"test_common.py",
|
||||
"test_connection_logging.py",
|
||||
"test_connection_monitoring.py",
|
||||
"test_connections_survive_primary_stepdown_spec.py",
|
||||
"test_create_entities.py",
|
||||
"test_crud_unified.py",
|
||||
@ -212,12 +215,14 @@ converted_tests = [
|
||||
"test_dns.py",
|
||||
"test_encryption.py",
|
||||
"test_examples.py",
|
||||
"test_grid_file.py",
|
||||
"test_gridfs.py",
|
||||
"test_gridfs_bucket.py",
|
||||
"test_gridfs_spec.py",
|
||||
"test_heartbeat_monitoring.py",
|
||||
"test_index_management.py",
|
||||
"test_grid_file.py",
|
||||
"test_load_balancer.py",
|
||||
"test_json_util_integration.py",
|
||||
"test_gridfs_spec.py",
|
||||
"test_load_balancer.py",
|
||||
"test_logger.py",
|
||||
"test_max_staleness.py",
|
||||
"test_monitoring.py",
|
||||
@ -233,9 +238,11 @@ converted_tests = [
|
||||
"test_retryable_writes_unified.py",
|
||||
"test_run_command.py",
|
||||
"test_sdam_monitoring_spec.py",
|
||||
"test_server_selection.py",
|
||||
"test_server_selection_in_window.py",
|
||||
"test_server_selection_logging.py",
|
||||
"test_session.py",
|
||||
"test_server_selection_rtt.py",
|
||||
"test_session.py",
|
||||
"test_sessions_unified.py",
|
||||
"test_srv_polling.py",
|
||||
"test_ssl.py",
|
||||
@ -245,6 +252,7 @@ converted_tests = [
|
||||
"test_unified_format.py",
|
||||
"test_versioned_api_integration.py",
|
||||
"unified_format.py",
|
||||
"utils_selection_tests.py",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user