Skip to content

Commit

Permalink
Fix til string to integer routines (microsoft#18276)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhecker authored Dec 5, 2024
1 parent c0d40c9 commit 89bc36c
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 44 deletions.
68 changes: 39 additions & 29 deletions src/inc/til/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,61 +361,71 @@ namespace til // Terminal Implementation Library. Also: "Today I Learned"
_TIL_INLINEPREFIX constexpr std::optional<uint64_t> parse_u64(const std::basic_string_view<T, Traits>& str, int base = 0) noexcept
{
// We don't have to test ptr for nullability, as we only access it under either condition:
// * str.length() > 0, for determining the base
// * str.size() > 0, for determining the base
// * ptr != end, when parsing the characters; if ptr is null, length will be 0 and thus end == ptr
#pragma warning(push)
#pragma warning(disable : 26429) // Symbol 'ptr' is never tested for nullness, it can be marked as not_null
#pragma warning(disable : 26451) // Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. [...]
#pragma warning(disable : 26481) // Don't use pointer arithmetic. Use span instead
auto ptr = str.data();
const auto end = ptr + str.length();
const auto end = ptr + str.size();
uint64_t accumulator = 0;
uint64_t base64 = base;
uint64_t base_uint64 = base;

if (base <= 0)
{
base64 = 10;
base_uint64 = 10;

if (ptr != end && *ptr == '0')
if (str.size() >= 2 && *ptr == '0')
{
base64 = 8;
base_uint64 = 8;
ptr += 1;

if (ptr != end && (*ptr == 'x' || *ptr == 'X'))
// Shift to lowercase to make the comparison easier.
const auto ch = *ptr | 0x20;

if (ch == 'b')
{
base_uint64 = 2;
ptr += 1;
}
else if (ch == 'x')
{
base64 = 16;
base_uint64 = 16;
ptr += 1;
}
}
}

if (ptr == end)
if (ptr == end || base_uint64 > 36)
{
return {};
}

const auto max_before_mul = UINT64_MAX / base_uint64;

for (;;)
{
uint64_t value = 0;
if (*ptr >= '0' && *ptr <= '9')
{
value = *ptr - '0';
}
else if (*ptr >= 'A' && *ptr <= 'F')
{
value = *ptr - 'A' + 10;
}
else if (*ptr >= 'a' && *ptr <= 'f')
{
value = *ptr - 'a' + 10;
}
else
{
return {};
}

const auto acc = accumulator * base64 + value;
if (acc < accumulator)
// Magic mapping from 0-9, A-Z, a-z to 0-35 go brrr. Invalid values are >35.
const uint64_t ch = *ptr;
const uint64_t sub = ch >= '0' && ch <= '9' ? (('0' - 1) & ~0x20) : (('A' - 1) & ~0x20) - 10;
// 'A' and 'a' reside at 0b...00001. By subtracting 1 we shift them to 0b...00000.
// We can then mask off 0b..1..... (= 0x20) to map a-z to A-Z.
// Once we subtract `sub`, all characters between Z and a will underflow.
// This results in A-Z and a-z mapping to 10-35.
const uint64_t value = ((ch - 1) & ~0x20) - sub;

// This is where we'd be using __builtin_mul_overflow and __builtin_add_overflow,
// but when MSVC finally added support for it in v17.7, it had a different name,
// only worked on x86, and only for signed integers. So, we can't use it.
const auto acc = accumulator * base_uint64 + value;
if (
// Check for invalid inputs.
value >= base_uint64 ||
// Check for multiplication overflow.
accumulator > max_before_mul ||
// Check for addition overflow.
acc < accumulator)
{
return {};
}
Expand Down
84 changes: 69 additions & 15 deletions src/til/ut_til/string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class StringTests
// clang++ -fsanitize=address,undefined,fuzzer -std=c++17 file.cpp
// and was run for 20min across 16 jobs in parallel.
#if 0
template<typename T, typename Traits>
std::optional<uint64_t> parse_u64(const std::basic_string_view<T, Traits>& str, int base = 0) noexcept
{
// ... implementation ...
}

extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size)
{
while (size > 0 && (isspace(*data) || *data == '+' || *data == '-'))
Expand All @@ -93,27 +99,17 @@ class StringTests
}

char narrow_buffer[128];
wchar_t wide_buffer[128];

memcpy(narrow_buffer, data, size);
for (size_t i = 0; i < size; ++i)
{
wide_buffer[i] = data[i];
}

// strtoul requires a null terminator
narrow_buffer[size] = 0;
wide_buffer[size] = 0;

char* end;
const auto expected = strtoul(narrow_buffer, &end, 0);
if (end != narrow_buffer + size || expected >= ULONG_MAX / 16)
{
return 0;
}
const auto val = strtoull(narrow_buffer, &end, 0);
const auto bad = end != narrow_buffer + size || val == ULLONG_MAX;
const auto expected = bad ? std::nullopt : std::optional{ val };

const auto actual = parse_u64({ wide_buffer, size });
if (expected != actual)
const auto actual = parse_u64(std::string_view{ narrow_buffer, size });
if (expected != actual && actual.value_or(0) != ULLONG_MAX)
{
__builtin_trap();
}
Expand All @@ -122,6 +118,64 @@ class StringTests
}
#endif

TEST_METHOD(parse_u64_overflow)
{
VERIFY_ARE_EQUAL(UINT64_C(18446744073709551614), til::details::parse_u64(std::string_view{ "18446744073709551614" }));
VERIFY_ARE_EQUAL(UINT64_C(18446744073709551615), til::details::parse_u64(std::string_view{ "18446744073709551615" }));
VERIFY_ARE_EQUAL(std::nullopt, til::details::parse_u64(std::string_view{ "18446744073709551616" }));
VERIFY_ARE_EQUAL(std::nullopt, til::details::parse_u64(std::string_view{ "18446744073709551617" }));
VERIFY_ARE_EQUAL(std::nullopt, til::details::parse_u64(std::string_view{ "88888888888888888888" }));
}

TEST_METHOD(parse_unsigned)
{
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>(""));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("0x"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("Z"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("0xZ"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("0Z"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("123abc"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("0123abc"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_unsigned<uint32_t>("0x100000000"));
VERIFY_ARE_EQUAL(0u, til::parse_unsigned<uint32_t>("0"));
VERIFY_ARE_EQUAL(0u, til::parse_unsigned<uint32_t>("0x0"));
VERIFY_ARE_EQUAL(0123u, til::parse_unsigned<uint32_t>("0123"));
VERIFY_ARE_EQUAL(123u, til::parse_unsigned<uint32_t>("123"));
VERIFY_ARE_EQUAL(0x123u, til::parse_unsigned<uint32_t>("0x123"));
VERIFY_ARE_EQUAL(0x123abcu, til::parse_unsigned<uint32_t>("0x123abc"));
VERIFY_ARE_EQUAL(0X123ABCu, til::parse_unsigned<uint32_t>("0X123ABC"));
VERIFY_ARE_EQUAL(UINT32_MAX, til::parse_unsigned<uint32_t>("0xffffffff"));
VERIFY_ARE_EQUAL(UINT32_MAX, til::parse_unsigned<uint32_t>("4294967295"));
}

TEST_METHOD(parse_signed)
{
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>(""));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("-"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("--"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("--0"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("-0Z"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("-123abc"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("-0123abc"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("0x80000000"));
VERIFY_ARE_EQUAL(std::nullopt, til::parse_signed<int32_t>("-0x80000001"));
VERIFY_ARE_EQUAL(0, til::parse_signed<int32_t>("0"));
VERIFY_ARE_EQUAL(0, til::parse_signed<int32_t>("-0"));
VERIFY_ARE_EQUAL(0, til::parse_signed<int32_t>("-0x0"));
VERIFY_ARE_EQUAL(0123, til::parse_signed<int32_t>("0123"));
VERIFY_ARE_EQUAL(123, til::parse_signed<int32_t>("123"));
VERIFY_ARE_EQUAL(0x123, til::parse_signed<int32_t>("0x123"));
VERIFY_ARE_EQUAL(-0123, til::parse_signed<int32_t>("-0123"));
VERIFY_ARE_EQUAL(-123, til::parse_signed<int32_t>("-123"));
VERIFY_ARE_EQUAL(-0x123, til::parse_signed<int32_t>("-0x123"));
VERIFY_ARE_EQUAL(-0x123abc, til::parse_signed<int32_t>("-0x123abc"));
VERIFY_ARE_EQUAL(-0X123ABC, til::parse_signed<int32_t>("-0X123ABC"));
VERIFY_ARE_EQUAL(INT32_MIN, til::parse_signed<int32_t>("-0x80000000"));
VERIFY_ARE_EQUAL(INT32_MIN, til::parse_signed<int32_t>("-2147483648"));
VERIFY_ARE_EQUAL(INT32_MAX, til::parse_signed<int32_t>("0x7fffffff"));
VERIFY_ARE_EQUAL(INT32_MAX, til::parse_signed<int32_t>("2147483647"));
}

TEST_METHOD(tolower_ascii)
{
for (wchar_t ch = 0; ch < 128; ++ch)
Expand Down

0 comments on commit 89bc36c

Please sign in to comment.