Migrate all the remaining logic to Rust (#669)

This commit is contained in:
Alex Gaynor 2023-11-23 14:44:52 -05:00 committed by GitHub
parent 0a9e2fc0e3
commit fb89f7c975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 143 additions and 122 deletions

View File

@ -45,7 +45,9 @@ dependencies = [
"base64",
"bcrypt",
"bcrypt-pbkdf",
"getrandom",
"pyo3",
"subtle",
]
[[package]]

View File

@ -10,6 +10,8 @@ pyo3 = { version = "0.20.0" }
bcrypt = "0.15"
bcrypt-pbkdf = "0.10.0"
base64 = "0.21.5"
subtle = "2.5"
getrandom = "0.2"
[features]
extension-module = ["pyo3/extension-module"]

View File

@ -5,7 +5,10 @@
#![deny(rust_2018_idioms)]
use base64::Engine;
use pyo3::PyTypeInfo;
use std::convert::TryInto;
use std::io::Write;
use subtle::ConstantTimeEq;
pub const BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
&base64::alphabet::BCRYPT,
@ -13,9 +16,43 @@ pub const BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::Genera
);
#[pyo3::prelude::pyfunction]
fn encode_base64<'p>(py: pyo3::Python<'p>, data: &[u8]) -> &'p pyo3::types::PyBytes {
let output = BASE64_ENGINE.encode(data);
pyo3::types::PyBytes::new(py, output.as_bytes())
fn gensalt<'p>(
py: pyo3::Python<'p>,
rounds: Option<u16>,
prefix: Option<&[u8]>,
) -> pyo3::PyResult<&'p pyo3::types::PyBytes> {
let rounds = rounds.unwrap_or(12);
let prefix = prefix.unwrap_or(b"2b");
if prefix != b"2a" && prefix != b"2b" {
return Err(pyo3::exceptions::PyValueError::new_err(
"Supported prefixes are b'2a' or b'2b'",
));
}
if !(4..=31).contains(&rounds) {
return Err(pyo3::exceptions::PyValueError::new_err("Invalid rounds"));
}
let mut salt = [0; 16];
getrandom::getrandom(&mut salt).unwrap();
let encoded_salt = BASE64_ENGINE.encode(salt);
pyo3::types::PyBytes::new_with(
py,
1 + prefix.len() + 1 + 2 + 1 + encoded_salt.len(),
|mut b| {
write!(b, "$").unwrap();
b.write_all(prefix).unwrap();
write!(b, "$").unwrap();
write!(b, "{:02.2}", rounds).unwrap();
write!(b, "$").unwrap();
b.write_all(encoded_salt.as_bytes()).unwrap();
Ok(())
},
)
}
#[pyo3::prelude::pyfunction]
@ -24,6 +61,14 @@ fn hashpass<'p>(
password: &[u8],
salt: &[u8],
) -> pyo3::PyResult<&'p pyo3::types::PyBytes> {
// bcrypt originally suffered from a wraparound bug:
// http://www.openwall.com/lists/oss-security/2012/01/02/4
// This bug was corrected in the OpenBSD source by truncating inputs to 72
// bytes on the updated prefix $2b$, but leaving $2a$ unchanged for
// compatibility. However, pyca/bcrypt 2.0.0 *did* correctly truncate inputs
// on $2a$, so we do it here to preserve compatibility with 2.0.0
let password = &password[..password.len().min(72)];
// salt here is not just the salt bytes, but rather an encoded value
// containing a version number, number of rounds, and the salt.
// Should be [prefix, cost, hash]. This logic is copied from `bcrypt`
@ -65,14 +110,59 @@ fn hashpass<'p>(
))
}
#[pyo3::prelude::pyfunction]
fn checkpass(
py: pyo3::Python<'_>,
password: &[u8],
hashed_password: &[u8],
) -> pyo3::PyResult<bool> {
Ok(hashpass(py, password, hashed_password)?
.as_bytes()
.ct_eq(hashed_password)
.into())
}
#[pyo3::prelude::pyfunction]
fn pbkdf<'p>(
py: pyo3::Python<'p>,
password: &[u8],
salt: &[u8],
rounds: u32,
desired_key_bytes: usize,
rounds: u32,
ignore_few_rounds: Option<bool>,
) -> pyo3::PyResult<&'p pyo3::types::PyBytes> {
let ignore_few_rounds = ignore_few_rounds.unwrap_or(false);
if password.is_empty() || salt.is_empty() {
return Err(pyo3::exceptions::PyValueError::new_err(
"password and salt must not be empty",
));
}
if desired_key_bytes == 0 || desired_key_bytes > 512 {
return Err(pyo3::exceptions::PyValueError::new_err(
"desired_key_bytes must be 1-512",
));
}
if rounds < 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"rounds must be 1 or more",
));
}
if rounds < 50 && !ignore_few_rounds {
// They probably think bcrypt.kdf()'s rounds parameter is logarithmic,
// expecting this value to be slow enough (it probably would be if this
// were bcrypt). Emit a warning.
pyo3::PyErr::warn(
py,
pyo3::exceptions::PyUserWarning::type_object(py),
&format!("Warning: bcrypt.kdf() called with only {rounds} round(s). This few is not secure: the parameter is linear, like PBKDF2."),
3
)?;
}
pyo3::types::PyBytes::new_with(py, desired_key_bytes, |output| {
py.allow_threads(|| {
bcrypt_pbkdf::bcrypt_pbkdf(password, salt, rounds, output).unwrap();
@ -83,8 +173,9 @@ fn pbkdf<'p>(
#[pyo3::prelude::pymodule]
fn _bcrypt(_py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
m.add_function(pyo3::wrap_pyfunction!(encode_base64, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(gensalt, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(hashpass, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(checkpass, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(pbkdf, m)?)?;
Ok(())

View File

@ -16,10 +16,6 @@
from __future__ import absolute_import
from __future__ import division
import hmac
import os
import warnings
from .__about__ import (
__author__,
__copyright__,
@ -49,79 +45,7 @@ __all__ = [
]
def gensalt(rounds: int = 12, prefix: bytes = b"2b") -> bytes:
if prefix not in (b"2a", b"2b"):
raise ValueError("Supported prefixes are b'2a' or b'2b'")
if rounds < 4 or rounds > 31:
raise ValueError("Invalid rounds")
salt = os.urandom(16)
output = _bcrypt.encode_base64(salt)
return (
b"$"
+ prefix
+ b"$"
+ ("%2.2u" % rounds).encode("ascii")
+ b"$"
+ output
)
def hashpw(password: bytes, salt: bytes) -> bytes:
if isinstance(password, str) or isinstance(salt, str):
raise TypeError("Strings must be encoded before hashing")
# bcrypt originally suffered from a wraparound bug:
# http://www.openwall.com/lists/oss-security/2012/01/02/4
# This bug was corrected in the OpenBSD source by truncating inputs to 72
# bytes on the updated prefix $2b$, but leaving $2a$ unchanged for
# compatibility. However, pyca/bcrypt 2.0.0 *did* correctly truncate inputs
# on $2a$, so we do it here to preserve compatibility with 2.0.0
password = password[:72]
return _bcrypt.hashpass(password, salt)
def checkpw(password: bytes, hashed_password: bytes) -> bool:
if isinstance(password, str) or isinstance(hashed_password, str):
raise TypeError("Strings must be encoded before checking")
ret = hashpw(password, hashed_password)
return hmac.compare_digest(ret, hashed_password)
def kdf(
password: bytes,
salt: bytes,
desired_key_bytes: int,
rounds: int,
ignore_few_rounds: bool = False,
) -> bytes:
if isinstance(password, str) or isinstance(salt, str):
raise TypeError("Strings must be encoded before hashing")
if len(password) == 0 or len(salt) == 0:
raise ValueError("password and salt must not be empty")
if desired_key_bytes <= 0 or desired_key_bytes > 512:
raise ValueError("desired_key_bytes must be 1-512")
if rounds < 1:
raise ValueError("rounds must be 1 or more")
if rounds < 50 and not ignore_few_rounds:
# They probably think bcrypt.kdf()'s rounds parameter is logarithmic,
# expecting this value to be slow enough (it probably would be if this
# were bcrypt). Emit a warning.
warnings.warn(
(
"Warning: bcrypt.kdf() called with only {0} round(s). "
"This few is not secure: the parameter is linear, like PBKDF2."
).format(rounds),
UserWarning,
stacklevel=2,
)
return _bcrypt.pbkdf(password, salt, rounds, desired_key_bytes)
gensalt = _bcrypt.gensalt
hashpw = _bcrypt.hashpass
checkpw = _bcrypt.checkpass
kdf = _bcrypt.pbkdf

View File

@ -1,7 +1,10 @@
import typing
def encode_base64(data: bytes) -> bytes: ...
def gensalt(rounds: int = 12, prefix: bytes = b"2b") -> bytes: ...
def hashpass(password: bytes, salt: bytes) -> bytes: ...
def checkpass(password: bytes, hashed_password: bytes) -> bool: ...
def pbkdf(
password: bytes, salt: bytes, rounds: int, desired_key_bytes: int
password: bytes,
salt: bytes,
rounds: int,
desired_key_bytes: int,
ignore_few_rounds: bool = False,
) -> bytes: ...

View File

@ -1,5 +1,3 @@
import os
import pytest
import bcrypt
@ -175,39 +173,40 @@ _2y_test_vectors = [
def test_gensalt_basic(monkeypatch):
monkeypatch.setattr(os, "urandom", lambda n: b"0000000000000000")
assert bcrypt.gensalt() == b"$2b$12$KB.uKB.uKB.uKB.uKB.uK."
salt = bcrypt.gensalt()
assert salt.startswith(b"$2b$12$")
@pytest.mark.parametrize(
("rounds", "expected"),
("rounds", "expected_prefix"),
[
(4, b"$2b$04$KB.uKB.uKB.uKB.uKB.uK."),
(5, b"$2b$05$KB.uKB.uKB.uKB.uKB.uK."),
(6, b"$2b$06$KB.uKB.uKB.uKB.uKB.uK."),
(7, b"$2b$07$KB.uKB.uKB.uKB.uKB.uK."),
(8, b"$2b$08$KB.uKB.uKB.uKB.uKB.uK."),
(9, b"$2b$09$KB.uKB.uKB.uKB.uKB.uK."),
(10, b"$2b$10$KB.uKB.uKB.uKB.uKB.uK."),
(11, b"$2b$11$KB.uKB.uKB.uKB.uKB.uK."),
(12, b"$2b$12$KB.uKB.uKB.uKB.uKB.uK."),
(13, b"$2b$13$KB.uKB.uKB.uKB.uKB.uK."),
(14, b"$2b$14$KB.uKB.uKB.uKB.uKB.uK."),
(15, b"$2b$15$KB.uKB.uKB.uKB.uKB.uK."),
(16, b"$2b$16$KB.uKB.uKB.uKB.uKB.uK."),
(17, b"$2b$17$KB.uKB.uKB.uKB.uKB.uK."),
(18, b"$2b$18$KB.uKB.uKB.uKB.uKB.uK."),
(19, b"$2b$19$KB.uKB.uKB.uKB.uKB.uK."),
(20, b"$2b$20$KB.uKB.uKB.uKB.uKB.uK."),
(21, b"$2b$21$KB.uKB.uKB.uKB.uKB.uK."),
(22, b"$2b$22$KB.uKB.uKB.uKB.uKB.uK."),
(23, b"$2b$23$KB.uKB.uKB.uKB.uKB.uK."),
(24, b"$2b$24$KB.uKB.uKB.uKB.uKB.uK."),
(4, b"$2b$04$"),
(5, b"$2b$05$"),
(6, b"$2b$06$"),
(7, b"$2b$07$"),
(8, b"$2b$08$"),
(9, b"$2b$09$"),
(10, b"$2b$10$"),
(11, b"$2b$11$"),
(12, b"$2b$12$"),
(13, b"$2b$13$"),
(14, b"$2b$14$"),
(15, b"$2b$15$"),
(16, b"$2b$16$"),
(17, b"$2b$17$"),
(18, b"$2b$18$"),
(19, b"$2b$19$"),
(20, b"$2b$20$"),
(21, b"$2b$21$"),
(22, b"$2b$22$"),
(23, b"$2b$23$"),
(24, b"$2b$24$"),
],
)
def test_gensalt_rounds_valid(rounds, expected, monkeypatch):
monkeypatch.setattr(os, "urandom", lambda n: b"0000000000000000")
assert bcrypt.gensalt(rounds) == expected
def test_gensalt_rounds_valid(rounds, expected_prefix):
salt = bcrypt.gensalt(rounds)
assert salt.startswith(expected_prefix)
@pytest.mark.parametrize("rounds", list(range(1, 4)))
@ -218,12 +217,12 @@ def test_gensalt_rounds_invalid(rounds):
def test_gensalt_bad_prefix():
with pytest.raises(ValueError):
bcrypt.gensalt(prefix="bad")
bcrypt.gensalt(prefix=b"bad")
def test_gensalt_2a_prefix(monkeypatch):
monkeypatch.setattr(os, "urandom", lambda n: b"0000000000000000")
assert bcrypt.gensalt(prefix=b"2a") == b"$2a$12$KB.uKB.uKB.uKB.uKB.uK."
salt = bcrypt.gensalt(prefix=b"2a")
assert salt.startswith(b"$2a$12$")
@pytest.mark.parametrize(("password", "salt", "hashed"), _test_vectors)
@ -478,7 +477,7 @@ def test_kdf_warn_rounds():
(b"", b"$2b$04$cVWp4XaNU8a4v1uMRum2SO", 10, 10, ValueError),
(b"password", b"", 10, 10, ValueError),
(b"password", b"$2b$04$cVWp4XaNU8a4v1uMRum2SO", 0, 10, ValueError),
(b"password", b"$2b$04$cVWp4XaNU8a4v1uMRum2SO", -3, 10, ValueError),
(b"password", b"$2b$04$cVWp4XaNU8a4v1uMRum2SO", -3, 10, OverflowError),
(b"password", b"$2b$04$cVWp4XaNU8a4v1uMRum2SO", 513, 10, ValueError),
(b"password", b"$2b$04$cVWp4XaNU8a4v1uMRum2SO", 20, 0, ValueError),
],