From 3c69277ba9e0f1e66c341c084320f7a763e1bdeb Mon Sep 17 00:00:00 2001 From: Yoctopuce dev Date: Mon, 21 Jul 2025 15:54:27 +0200 Subject: [PATCH] py/objint_longlong: Fix overflow check in mp_obj_int_get_checked. This is to fix an outstanding TODO. The test cases is using a range as this will exist in all builds, but `mp_obj_get_int` is used in many different parts of code where an overflow is more likely to occur. Signed-off-by: Yoctopuce dev --- py/objint_longlong.c | 13 +++++++++++-- tests/basics/int_64_basics.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/py/objint_longlong.c b/py/objint_longlong.c index 339ce7cfd..8b8fdc62e 100644 --- a/py/objint_longlong.c +++ b/py/objint_longlong.c @@ -308,8 +308,17 @@ mp_int_t mp_obj_int_get_truncated(mp_const_obj_t self_in) { } mp_int_t mp_obj_int_get_checked(mp_const_obj_t self_in) { - // TODO: Check overflow - return mp_obj_int_get_truncated(self_in); + if (mp_obj_is_small_int(self_in)) { + return MP_OBJ_SMALL_INT_VALUE(self_in); + } else { + const mp_obj_int_t *self = self_in; + long long value = self->val; + mp_int_t truncated = (mp_int_t)value; + if ((long long)truncated == value) { + return truncated; + } + } + mp_raise_msg(&mp_type_OverflowError, MP_ERROR_TEXT("overflow converting long int to machine word")); } mp_uint_t mp_obj_int_get_uint_checked(mp_const_obj_t self_in) { diff --git a/tests/basics/int_64_basics.py b/tests/basics/int_64_basics.py index 289ea49b6..2a161dac0 100644 --- a/tests/basics/int_64_basics.py +++ b/tests/basics/int_64_basics.py @@ -125,6 +125,22 @@ else: x = 1 << 62 print('a' * (x + 4 - x)) +# test overflow check in mp_obj_get_int_maybe +x = 1 << 32 +r = None +try: + r = range(0, x) +except OverflowError: + # 32-bit target, correctly handled the overflow of x + print("ok") +if r is not None: + if len(r) == x: + # 64-bit target, everything is just a small-int + print("ok") + else: + # 32-bit target that did not handle the overflow of x + print("unhandled overflow") + # negative shifts are invalid try: print((1 << 48) >> -4)