From: Simon Glass <sjg@chromium.org> At present the aes_get_rounds() and aes_get_keycols() functions compare the key_len parameter (in bits) directly against AES*_KEY_LENGTH constants (in bytes), causing incorrect round and column counts for non-128-bit keys. Additionally, aes_expand_key() uses key_len as a byte count in memcpy(), copying far more data than intended and causing buffer overflows. Specifically, for AES-256 (256-bit key) it compares 256 (bits) against 32 (bytes), failing the comparison. This causes AES-256 to use AES-128 parameters (10 rounds instead of 14) and the memcpy() to copy 256 bytes instead of 32. Fix by converting key_len from bits to bytes before comparisons and in memcpy. With this we get: - AES-128 (128 bits / 16 bytes): 10 rounds, 4 key columns - AES-192 (192 bits / 24 bytes): 12 rounds, 6 key columns - AES-256 (256 bits / 32 bytes): 14 rounds, 8 key columns Co-developed-by: Claude <noreply@anthropic.com> Signed-off-by: Simon Glass <sjg@chromium.org> Fixes: 8302d1708ae ("aes: add support of aes192 and aes256") --- lib/aes.c | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/aes.c b/lib/aes.c index 39ad4a990f0..3bcbeeab9af 100644 --- a/lib/aes.c +++ b/lib/aes.c @@ -513,10 +513,11 @@ static u8 rcon[11] = { static u32 aes_get_rounds(u32 key_len) { u32 rounds = AES128_ROUNDS; + u32 key_len_bytes = key_len / 8; /* Convert bits to bytes */ - if (key_len == AES192_KEY_LENGTH) + if (key_len_bytes == AES192_KEY_LENGTH) rounds = AES192_ROUNDS; - else if (key_len == AES256_KEY_LENGTH) + else if (key_len_bytes == AES256_KEY_LENGTH) rounds = AES256_ROUNDS; return rounds; @@ -525,10 +526,11 @@ static u32 aes_get_rounds(u32 key_len) static u32 aes_get_keycols(u32 key_len) { u32 keycols = AES128_KEYCOLS; + u32 key_len_bytes = key_len / 8; /* Convert bits to bytes */ - if (key_len == AES192_KEY_LENGTH) + if (key_len_bytes == AES192_KEY_LENGTH) keycols = AES192_KEYCOLS; - else if (key_len == AES256_KEY_LENGTH) + else if 
(key_len_bytes == AES256_KEY_LENGTH) keycols = AES256_KEYCOLS; return keycols; @@ -538,12 +540,13 @@ static u32 aes_get_keycols(u32 key_len) void aes_expand_key(u8 *key, u32 key_len, u8 *expkey) { u8 tmp0, tmp1, tmp2, tmp3, tmp4; - u32 idx, aes_rounds, aes_keycols; + uint idx, aes_rounds, aes_keycols; aes_rounds = aes_get_rounds(key_len); aes_keycols = aes_get_keycols(key_len); - memcpy(expkey, key, key_len); + /* key_len is in bits; convert to bytes */ + memcpy(expkey, key, key_len / 8); for (idx = aes_keycols; idx < AES_STATECOLS * (aes_rounds + 1); idx++) { tmp0 = expkey[4*idx - 4]; -- 2.43.0