py/modmath: Add domain error checking to sqrt, log, log2, log10.
These functions will raise 'ValueError: math domain error' on invalid input.
This commit is contained in:
committed by
Damien George
parent
f7c4f9a640
commit
17298af61e
34
py/modmath.c
34
py/modmath.c
@@ -25,6 +25,7 @@
|
||||
*/
|
||||
|
||||
#include "py/builtin.h"
|
||||
#include "py/nlr.h"
|
||||
|
||||
#if MICROPY_PY_BUILTINS_FLOAT && MICROPY_PY_MATH
|
||||
|
||||
@@ -35,7 +36,10 @@
|
||||
/// The `math` module provides some basic mathematical funtions for
|
||||
/// working with floating-point numbers.
|
||||
|
||||
//TODO: Change macros to check for overflow and raise OverflowError or RangeError
|
||||
STATIC NORETURN void math_error(void) {
|
||||
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "math domain error"));
|
||||
}
|
||||
|
||||
#define MATH_FUN_1(py_name, c_name) \
|
||||
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
|
||||
@@ -52,6 +56,16 @@
|
||||
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { mp_int_t x = MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj)); return mp_obj_new_int(x); } \
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
|
||||
|
||||
#define MATH_FUN_1_ERRCOND(py_name, c_name, error_condition) \
|
||||
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \
|
||||
mp_float_t x = mp_obj_get_float(x_obj); \
|
||||
if (error_condition) { \
|
||||
math_error(); \
|
||||
} \
|
||||
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(x)); \
|
||||
} \
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
|
||||
|
||||
#if MP_NEED_LOG2
|
||||
// 1.442695040888963407354163704 is 1/_M_LN2
|
||||
#define log2(x) (log(x) * 1.442695040888963407354163704)
|
||||
@@ -59,7 +73,7 @@
|
||||
|
||||
/// \function sqrt(x)
|
||||
/// Returns the square root of `x`.
|
||||
MATH_FUN_1(sqrt, sqrt)
|
||||
MATH_FUN_1_ERRCOND(sqrt, sqrt, (x < (mp_float_t)0.0))
|
||||
/// \function pow(x, y)
|
||||
/// Returns `x` to the power of `y`.
|
||||
MATH_FUN_2(pow, pow)
|
||||
@@ -69,9 +83,9 @@ MATH_FUN_1(exp, exp)
|
||||
/// \function expm1(x)
|
||||
MATH_FUN_1(expm1, expm1)
|
||||
/// \function log2(x)
|
||||
MATH_FUN_1(log2, log2)
|
||||
MATH_FUN_1_ERRCOND(log2, log2, (x <= (mp_float_t)0.0))
|
||||
/// \function log10(x)
|
||||
MATH_FUN_1(log10, log10)
|
||||
MATH_FUN_1_ERRCOND(log10, log10, (x <= (mp_float_t)0.0))
|
||||
/// \function cosh(x)
|
||||
MATH_FUN_1(cosh, cosh)
|
||||
/// \function sinh(x)
|
||||
@@ -139,11 +153,19 @@ MATH_FUN_1(lgamma, lgamma)
|
||||
|
||||
// log(x[, base])
|
||||
STATIC mp_obj_t mp_math_log(mp_uint_t n_args, const mp_obj_t *args) {
|
||||
mp_float_t l = MICROPY_FLOAT_C_FUN(log)(mp_obj_get_float(args[0]));
|
||||
mp_float_t x = mp_obj_get_float(args[0]);
|
||||
if (x <= (mp_float_t)0.0) {
|
||||
math_error();
|
||||
}
|
||||
mp_float_t l = MICROPY_FLOAT_C_FUN(log)(x);
|
||||
if (n_args == 1) {
|
||||
return mp_obj_new_float(l);
|
||||
} else {
|
||||
return mp_obj_new_float(l / MICROPY_FLOAT_C_FUN(log)(mp_obj_get_float(args[1])));
|
||||
mp_float_t base = mp_obj_get_float(args[1]);
|
||||
if (base <= (mp_float_t)0.0) {
|
||||
math_error();
|
||||
}
|
||||
return mp_obj_new_float(l / MICROPY_FLOAT_C_FUN(log)(base));
|
||||
}
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_math_log_obj, 1, 2, mp_math_log);
|
||||
|
||||
Reference in New Issue
Block a user