diff --git a/examples/dgfip_c/const.h b/examples/dgfip_c/const.h index 87d404a0b..1a3c67176 100644 --- a/examples/dgfip_c/const.h +++ b/examples/dgfip_c/const.h @@ -95,7 +95,8 @@ extern void free_erreur(); #ifdef FLG_OPTIM_MIN_MAX #define my_floor(a) (floor_g((a) + 0.000001)) -#define my_arr(a) (floor_g((a) + 0.50005)) +/*#define my_arr(a) (floor_g((a) + 0.50005)) *//* Ancienne version (2021) */ +#define my_arr(a) (((a) < 0.0) ? ceil_g((a) - .50005) : floor_g((a) + .50005)) #else @@ -117,6 +118,7 @@ extern void free_erreur(); #endif /* FLG_OPTIM_MIN_MAX */ extern double floor_g(double); +extern double ceil_g(double); extern int multimax_def(int, char *); extern double multimax(double, double *); extern int modulo_def(int, int); diff --git a/examples/dgfip_c/enchain_static.c.inc b/examples/dgfip_c/enchain_static.c.inc index cf13295ed..dd2ba6a0d 100644 --- a/examples/dgfip_c/enchain_static.c.inc +++ b/examples/dgfip_c/enchain_static.c.inc @@ -313,6 +313,15 @@ double floor_g(double a) } } +double ceil_g(double a) +{ + if (fabs(a) <= LONG_MAX) { + return ceil(a); + } else { + return a; + } +} + int multimax_def(int nbopd, char *var) { int i = 0; diff --git a/src/mlang/backend_ir/bir_interpreter.ml b/src/mlang/backend_ir/bir_interpreter.ml index 0d44e7734..f9a81f6e8 100644 --- a/src/mlang/backend_ir/bir_interpreter.ml +++ b/src/mlang/backend_ir/bir_interpreter.ml @@ -76,6 +76,8 @@ module type S = sig val var_value_to_var_literal : var_value -> var_literal + val update_ctx_with_inputs : ctx -> Mir.literal Bir.VariableMap.t -> ctx + type run_error = | ErrorValue of string * Pos.t | FloatIndex of string * Pos.t @@ -95,24 +97,27 @@ module type S = sig val replace_undefined_with_input_variables : Mir.program -> Mir.VariableDict.t -> Mir.program + val print_output : Bir_interface.bir_function -> ctx -> unit + val raise_runtime_as_structured : run_error -> ctx -> Mir.program -> 'a + + val evaluate_expr : ctx -> Mir.program -> Bir.expression Pos.marked -> value + + val evaluate_program : Bir.program -> ctx -> int -> ctx end -module Make (N : Bir_number.NumberInterface) = struct +module Make (N : Bir_number.NumberInterface) (RF : Bir_roundops.RoundOpsFunctor) = +struct (* Careful : this behavior mimics the one imposed by the original Mlang compiler... *) + module R = RF (N) + type custom_float = N.t - let truncatef (x : N.t) : N.t = N.floor N.(x +. N.of_float 0.000001) + let truncatef (x : N.t) : N.t = R.truncatef x - (* Careful : rounding in M is done with this arbitrary behavior. We can't use - copysign here because [x < zero] is critical to have the correct behavior - on -0 *) - let roundf (x : N.t) = - N.of_int - (N.to_int - N.(x +. N.of_float (if N.(x < zero ()) then -0.50005 else 0.50005))) + let roundf (x : N.t) = R.roundf x type value = Number of N.t | Undefined @@ -773,95 +778,109 @@ module Make (N : Bir_number.NumberInterface) = struct else raise (RuntimeError (e, ctx)) end -module RegularFloatInterpreter = Make (Bir_number.RegularFloatNumber) -module MPFRInterpreter = Make (Bir_number.MPFRNumber) - module BigIntPrecision = struct let scaling_factor_bits = ref 64 end -module BigIntInterpreter = - Make (Bir_number.BigIntFixedPointNumber (BigIntPrecision)) -module IntervalInterpreter = Make (Bir_number.IntervalNumber) -module RationalInterpreter = Make (Bir_number.RationalNumber) +module MainframeLongSize = struct + let max_long = ref Int64.max_int +end -type value_sort = - | RegularFloat - | MPFR of int - | BigInt of int - | Interval - | Rational +module FloatDefInterp = + Make (Bir_number.RegularFloatNumber) (Bir_roundops.DefaultRoundOps) +module FloatMultInterp = + Make (Bir_number.RegularFloatNumber) (Bir_roundops.MultiRoundOps) +module FloatMfInterp = + Make + (Bir_number.RegularFloatNumber) + (Bir_roundops.MainframeRoundOps (MainframeLongSize)) +module MPFRDefInterp = + Make (Bir_number.MPFRNumber) (Bir_roundops.DefaultRoundOps) +module MPFRMultInterp = + Make (Bir_number.MPFRNumber) (Bir_roundops.MultiRoundOps) +module MPFRMfInterp = + Make + (Bir_number.MPFRNumber) + (Bir_roundops.MainframeRoundOps (MainframeLongSize)) +module BigIntDefInterp = + Make + (Bir_number.BigIntFixedPointNumber + (BigIntPrecision)) + (Bir_roundops.DefaultRoundOps) +module BigIntMultInterp = + Make + (Bir_number.BigIntFixedPointNumber + (BigIntPrecision)) + (Bir_roundops.MultiRoundOps) +module BigIntMfInterp = + Make + (Bir_number.BigIntFixedPointNumber + (BigIntPrecision)) + (Bir_roundops.MainframeRoundOps (MainframeLongSize)) +module IntvDefInterp = + Make (Bir_number.IntervalNumber) (Bir_roundops.DefaultRoundOps) +module IntvMultInterp = + Make (Bir_number.IntervalNumber) (Bir_roundops.MultiRoundOps) +module IntvMfInterp = + Make + (Bir_number.IntervalNumber) + (Bir_roundops.MainframeRoundOps (MainframeLongSize)) +module RatDefInterp = + Make (Bir_number.RationalNumber) (Bir_roundops.DefaultRoundOps) +module RatMultInterp = + Make (Bir_number.RationalNumber) (Bir_roundops.MultiRoundOps) +module RatMfInterp = + Make + (Bir_number.RationalNumber) + (Bir_roundops.MainframeRoundOps (MainframeLongSize)) + +let get_interp (sort : Cli.value_sort) (roundops : Cli.round_ops) : (module S) = + match (sort, roundops) with + | RegularFloat, RODefault -> (module FloatDefInterp) + | RegularFloat, ROMulti -> (module FloatMultInterp) + | RegularFloat, ROMainframe _ -> (module FloatMfInterp) + | MPFR _, RODefault -> (module MPFRDefInterp) + | MPFR _, ROMulti -> (module MPFRMultInterp) + | MPFR _, ROMainframe _ -> (module MPFRMfInterp) + | BigInt _, RODefault -> (module BigIntDefInterp) + | BigInt _, ROMulti -> (module BigIntMultInterp) + | BigInt _, ROMainframe _ -> (module BigIntMfInterp) + | Interval, RODefault -> (module IntvDefInterp) + | Interval, ROMulti -> (module IntvMultInterp) + | Interval, ROMainframe _ -> (module IntvMfInterp) + | Rational, RODefault -> (module RatDefInterp) + | Rational, ROMulti -> (module RatMultInterp) + | Rational, ROMainframe _ -> (module RatMfInterp) + +let prepare_interp (sort : Cli.value_sort) (roundops : Cli.round_ops) : unit = + begin + match sort with + | MPFR prec -> Mpfr.set_default_prec prec + | BigInt prec -> BigIntPrecision.scaling_factor_bits := prec + | Interval -> Mpfr.set_default_prec 64 + | _ -> () + end; + match roundops with + | ROMainframe long_size -> + let max_long = + if long_size = 32 then Int64.of_int32 Int32.max_int + else if long_size = 64 then Int64.max_int + else assert false + (* checked when parsing command line *) + in + MainframeLongSize.max_long := max_long + | _ -> () let evaluate_program (bir_func : Bir_interface.bir_function) (p : Bir.program) (inputs : Mir.literal Bir.VariableMap.t) (code_loc_start_value : int) - (sort : value_sort) : unit -> unit = - match sort with - | RegularFloat -> - let ctx = - RegularFloatInterpreter.update_ctx_with_inputs - RegularFloatInterpreter.empty_ctx inputs - in - let ctx = - RegularFloatInterpreter.evaluate_program p ctx code_loc_start_value - in - fun () -> RegularFloatInterpreter.print_output bir_func ctx - | MPFR prec -> - Mpfr.set_default_prec prec; - let ctx = - MPFRInterpreter.update_ctx_with_inputs MPFRInterpreter.empty_ctx inputs - in - let ctx = MPFRInterpreter.evaluate_program p ctx code_loc_start_value in - fun () -> MPFRInterpreter.print_output bir_func ctx - | BigInt prec -> - BigIntPrecision.scaling_factor_bits := prec; - let ctx = - BigIntInterpreter.update_ctx_with_inputs BigIntInterpreter.empty_ctx - inputs - in - let ctx = BigIntInterpreter.evaluate_program p ctx code_loc_start_value in - fun () -> BigIntInterpreter.print_output bir_func ctx - | Interval -> - Mpfr.set_default_prec 64; - let ctx = - IntervalInterpreter.update_ctx_with_inputs IntervalInterpreter.empty_ctx - inputs - in - let ctx = - IntervalInterpreter.evaluate_program p ctx code_loc_start_value - in - fun () -> IntervalInterpreter.print_output bir_func ctx - | Rational -> - let ctx = - RationalInterpreter.update_ctx_with_inputs RationalInterpreter.empty_ctx - inputs - in - let ctx = - RationalInterpreter.evaluate_program p ctx code_loc_start_value - in - fun () -> RationalInterpreter.print_output bir_func ctx + (sort : Cli.value_sort) (roundops : Cli.round_ops) : unit -> unit = + prepare_interp sort roundops; + let module Interp = (val get_interp sort roundops : S) in + let ctx = Interp.update_ctx_with_inputs Interp.empty_ctx inputs in + let ctx = Interp.evaluate_program p ctx code_loc_start_value in + fun () -> Interp.print_output bir_func ctx let evaluate_expr (p : Mir.program) (e : Bir.expression Pos.marked) - (sort : value_sort) : Mir.literal = - let f p e = - match sort with - | RegularFloat -> - RegularFloatInterpreter.value_to_literal - (RegularFloatInterpreter.evaluate_expr - RegularFloatInterpreter.empty_ctx p e) - | MPFR prec -> - Mpfr.set_default_prec prec; - MPFRInterpreter.value_to_literal - (MPFRInterpreter.evaluate_expr MPFRInterpreter.empty_ctx p e) - | BigInt prec -> - BigIntPrecision.scaling_factor_bits := prec; - BigIntInterpreter.value_to_literal - (BigIntInterpreter.evaluate_expr BigIntInterpreter.empty_ctx p e) - | Interval -> - Mpfr.set_default_prec 64; - IntervalInterpreter.value_to_literal - (IntervalInterpreter.evaluate_expr IntervalInterpreter.empty_ctx p e) - | Rational -> - RationalInterpreter.value_to_literal - (RationalInterpreter.evaluate_expr RationalInterpreter.empty_ctx p e) - in - f p e + (sort : Cli.value_sort) (roundops : Cli.round_ops) : Mir.literal = + let module Interp = (val get_interp sort roundops : S) in + Interp.value_to_literal (Interp.evaluate_expr Interp.empty_ctx p e) diff --git a/src/mlang/backend_ir/bir_interpreter.mli b/src/mlang/backend_ir/bir_interpreter.mli index 99e33e962..142dc63d9 100644 --- a/src/mlang/backend_ir/bir_interpreter.mli +++ b/src/mlang/backend_ir/bir_interpreter.mli @@ -23,7 +23,7 @@ type var_literal = | SimpleVar of Mir.literal | TableVar of int * Mir.literal array -(**{1 Instrumentation of he interpreter}*) +(**{1 Instrumentation of the interpreter}*) (** The BIR interpreter can be instrumented to record which program locations have been executed. *) @@ -96,6 +96,8 @@ module type S = sig val var_value_to_var_literal : var_value -> var_literal + val update_ctx_with_inputs : ctx -> Mir.literal Bir.VariableMap.t -> ctx + (** Interpreter runtime errors *) type run_error = | ErrorValue of string * Pos.t @@ -118,41 +120,82 @@ module type S = sig (** Before execution of the program, replaces the [undefined] stubs for input variables by their true input value *) + val print_output : Bir_interface.bir_function -> ctx -> unit + val raise_runtime_as_structured : run_error -> ctx -> Mir.program -> 'a (** Raises a runtime error with a formatted error message and context *) + + val evaluate_expr : ctx -> Mir.program -> Bir.expression Pos.marked -> value + + val evaluate_program : Bir.program -> ctx -> int -> ctx end -module RegularFloatInterpreter : S +module FloatDefInterp : S +(** The different interpreters, which combine a representation of numbers and + rounding operations. The first part of the name corresponds to the + representation of numbers, and is one of the following: + + - Float: "regular" IEE754 floating point numbers + - MPFR: arbitrary precision floating-point numbers using MPFR + - BigInt: fixed-point numbers + - Intv: intervals of two IEEE754 floating-point numbers + - Rat: rationals + + The second part indicates the rounding operations to use, and is one of the + following: + + - Def: use the default rounding operations, those of the PC/single-thread + context + - Multi: use the rouding operations of the PC/multi-thread context + - Mf: use the rounding operations of the mainframe context *) + +module FloatMultInterp : S + +module FloatMfInterp : S + +module MPFRDefInterp : S + +module MPFRMultInterp : S + +module MPFRMfInterp : S + +module BigIntDefInterp : S + +module BigIntMultInterp : S + +module BigIntMfInterp : S + +module IntvDefInterp : S + +module IntvMultInterp : S -module MPFRInterpreter : S +module IntvMfInterp : S -module BigIntInterpreter : S +module RatDefInterp : S -module IntervalInterpreter : S +module RatMultInterp : S -module RationalInterpreter : S +module RatMfInterp : S (** {1 Generic interpretation API}*) -(** According on the [value_sort], a specific interpreter will be called with - the right kind of floating-point value *) -type value_sort = - | RegularFloat - | MPFR of int (** bitsize of the floats *) - | BigInt of int (** precision of the fixed point *) - | Interval - | Rational +val get_interp : Cli.value_sort -> Cli.round_ops -> (module S) val evaluate_program : Bir_interface.bir_function -> Bir.program -> Mir.literal Bir.VariableMap.t -> int -> - value_sort -> + Cli.value_sort -> + Cli.round_ops -> unit -> unit (** Main interpreter function *) val evaluate_expr : - Mir.program -> Bir.expression Pos.marked -> value_sort -> Mir.literal + Mir.program -> + Bir.expression Pos.marked -> + Cli.value_sort -> + Cli.round_ops -> + Mir.literal (** Interprets only an expression *) diff --git a/src/mlang/backend_ir/bir_number.ml b/src/mlang/backend_ir/bir_number.ml index 65fb26b01..9470ea560 100644 --- a/src/mlang/backend_ir/bir_number.ml +++ b/src/mlang/backend_ir/bir_number.ml @@ -19,8 +19,12 @@ module type NumberInterface = sig val format_t : Format.formatter -> t -> unit + val abs : t -> t + val floor : t -> t + val ceil : t -> t + val of_int : Int64.t -> t val to_int : t -> Int64.t @@ -69,8 +73,12 @@ module RegularFloatNumber : NumberInterface = struct let format_t fmt f = Format.fprintf fmt "%f" f + let abs x = Float.abs x + let floor x = Float.floor x + let ceil x = Float.ceil x + let of_int i = Int64.to_float i let to_int f = Int64.of_float f @@ -112,11 +120,21 @@ module RegularFloatNumber : NumberInterface = struct let is_zero x = x = 0. end -let mpfr_float (x : Mpfrf.t) : Mpfrf.t = +let mpfr_abs (x : Mpfrf.t) : Mpfrf.t = + let out = Mpfr.init2 (Mpfr.get_prec x) in + ignore (Mpfr.abs out x Mpfr.Near); + Mpfrf.of_mpfr out + +let mpfr_floor (x : Mpfrf.t) : Mpfrf.t = let out = Mpfr.init () in ignore (Mpfr.floor out x); Mpfrf.of_mpfr out +let mpfr_ceil (x : Mpfrf.t) : Mpfrf.t = + let out = Mpfr.init () in + ignore (Mpfr.ceil out x); + Mpfrf.of_mpfr out + module MPFRNumber : NumberInterface = struct type t = Mpfrf.t @@ -124,7 +142,11 @@ module MPFRNumber : NumberInterface = struct let format_t fmt f = Format.fprintf fmt "%a" Mpfrf.print f - let floor (x : t) : t = mpfr_float x + let abs (x : t) : t = mpfr_abs x + + let floor (x : t) : t = mpfr_floor x + + let ceil (x : t) : t = mpfr_ceil x let of_int i = Mpfrf.of_int (Int64.to_int i) rounding @@ -175,9 +197,19 @@ module IntervalNumber : NumberInterface = struct let format_t fmt f = Format.fprintf fmt "[%a;%a]" Mpfrf.print f.down Mpfrf.print f.up + let abs x = + let id = mpfr_abs x.down in + let iu = mpfr_abs x.up in + v id iu + let floor x = - let id = mpfr_float x.down in - let iu = mpfr_float x.up in + let id = mpfr_floor x.down in + let iu = mpfr_floor x.up in + v id iu + + let ceil x = + let id = mpfr_ceil x.down in + let iu = mpfr_ceil x.up in v id iu let of_int i = @@ -276,10 +308,17 @@ module RationalNumber : NumberInterface = struct let format_t fmt f = Mpqf.print fmt f + let abs x = Mpqf.abs x + let floor x = let num = Mpqf.get_num x in let dem = Mpqf.get_den x in - Mpqf.of_mpz (Mpzf.tdiv_q num dem) + Mpqf.of_mpz (Mpzf.fdiv_q num dem) + + let ceil x = + let num = Mpqf.get_num x in + let dem = Mpqf.get_den x in + Mpqf.of_mpz (Mpzf.cdiv_q num dem) let of_int i = Mpqf.of_int (Int64.to_int i) @@ -352,7 +391,15 @@ end) : NumberInterface = struct let int_part = Mpzf.mul int_part (precision_modulo ()) in (frac_part, int_part) - let floor x = snd (modf x) + let abs x = Mpzf.abs x + + let floor x = + let prec_mod = precision_modulo () in + Mpzf.mul (Mpzf.fdiv_q x prec_mod) prec_mod + + let ceil x = + let prec_mod = precision_modulo () in + Mpzf.mul (Mpzf.cdiv_q x prec_mod) prec_mod let of_int i = Mpzf.mul (Mpzf.of_int (Int64.to_int i)) (precision_modulo ()) diff --git a/src/mlang/backend_ir/bir_number.mli b/src/mlang/backend_ir/bir_number.mli index 0791454d4..7c74dc230 100644 --- a/src/mlang/backend_ir/bir_number.mli +++ b/src/mlang/backend_ir/bir_number.mli @@ -19,8 +19,12 @@ module type NumberInterface = sig val format_t : Format.formatter -> t -> unit + val abs : t -> t + val floor : t -> t + val ceil : t -> t + val of_int : Int64.t -> t val to_int : t -> Int64.t @@ -64,7 +68,7 @@ end module RegularFloatNumber : NumberInterface -val mpfr_float : Mpfrf.t -> Mpfrf.t +val mpfr_floor : Mpfrf.t -> Mpfrf.t module MPFRNumber : NumberInterface diff --git a/src/mlang/backend_ir/bir_roundops.ml b/src/mlang/backend_ir/bir_roundops.ml new file mode 100644 index 000000000..dd793ef99 --- /dev/null +++ b/src/mlang/backend_ir/bir_roundops.ml @@ -0,0 +1,77 @@ +(* Copyright (C) 2019-2021 Inria, contributor: Denis Merigoux + + + This program is free software: you can redistribute it and/or modify it under + the terms of the GNU General Public License as published by the Free Software + Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with + this program. If not, see . *) + +module type RoundOpsInterface = sig + type t + + val truncatef : t -> t + + val roundf : t -> t +end + +module type RoundOpsFunctor = functor (N : Bir_number.NumberInterface) -> + RoundOpsInterface with type t = N.t + +module DefaultRoundOps (N : Bir_number.NumberInterface) : + RoundOpsInterface with type t = N.t = struct + type t = N.t + + let truncatef (x : N.t) : N.t = N.floor N.(x +. N.of_float 0.000001) + + (* Careful : rounding in M is done with this arbitrary behavior. We can't use + copysign here because [x < zero] is critical to have the correct behavior + on -0 *) + let roundf (x : N.t) = + N.of_int + (N.to_int + N.(x +. N.of_float (if N.(x < zero ()) then -0.50005 else 0.50005))) +end + +module MultiRoundOps (N : Bir_number.NumberInterface) : + RoundOpsInterface with type t = N.t = struct + type t = N.t + + let truncatef (x : N.t) : N.t = N.floor N.(x +. N.of_float 0.000001) + + let roundf (x : N.t) = + let n_0_5 = N.of_float 0.5 in + let n_100000_0 = N.of_float 100000.0 in + let v1 = N.floor x in + let v2 = N.(N.floor (((x -. v1) *. n_100000_0) +. n_0_5) /. n_100000_0) in + N.floor N.(v1 +. v2 +. n_0_5) +end + +module MainframeRoundOps (L : sig + val max_long : Int64.t ref +end) +(N : Bir_number.NumberInterface) : RoundOpsInterface with type t = N.t = struct + type t = N.t + + let floor_g (x : N.t) : N.t = + if N.abs x <= N.of_int !L.max_long then N.floor x else x + + let ceil_g (x : N.t) : N.t = + if N.abs x <= N.of_int !L.max_long then N.ceil x else x + + let truncatef (x : N.t) : N.t = floor_g N.(x +. N.of_float 0.000001) + + (* Careful : rounding in M is done with this arbitrary behavior. We can't use + copysign here because [x < zero] is critical to have the correct behavior + on -0 *) + let roundf (x : N.t) = + if N.(x < zero ()) then ceil_g N.(x -. N.of_float 0.50005) + else floor_g N.(x +. N.of_float 0.50005) +end diff --git a/src/mlang/backend_ir/bir_roundops.mli b/src/mlang/backend_ir/bir_roundops.mli new file mode 100644 index 000000000..d05d865e9 --- /dev/null +++ b/src/mlang/backend_ir/bir_roundops.mli @@ -0,0 +1,44 @@ +(* Copyright (C) 2019-2021 Inria, contributor: Denis Merigoux + + + This program is free software: you can redistribute it and/or modify it under + the terms of the GNU General Public License as published by the Free Software + Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with + this program. If not, see . *) + +(** Rounding operations to use in the interpreter *) +module type RoundOpsInterface = sig + type t + + val truncatef : t -> t + + val roundf : t -> t +end + +(** The actual implementation of rounding operations depends on the chosen + representation of numbers, hence we need a functor *) +module type RoundOpsFunctor = functor (N : Bir_number.NumberInterface) -> + RoundOpsInterface with type t = N.t + +module DefaultRoundOps : RoundOpsFunctor +(** Default rounding operations: those used in the PC/single-thread context *) + +module MultiRoundOps : RoundOpsFunctor +(** Multithread rounding operations: those used in the PC/multi-thread context *) + +(** Mainframe rounding operations: those used in the mainframe context. As the + behavior depends on the sie of the `long` type, this size must be given as + an argument (and should be either 32 or 64). *) +module MainframeRoundOps : functor + (L : sig + val max_long : Int64.t ref + end) + -> RoundOpsFunctor diff --git a/src/mlang/driver.ml b/src/mlang/driver.ml index e69a5cc20..380e227dc 100644 --- a/src/mlang/driver.ml +++ b/src/mlang/driver.ml @@ -79,11 +79,55 @@ let driver (files : string list) (debug : bool) (var_info_debug : string list) (mpp_file : string) (output : string option) (run_all_tests : string option) (run_test : string option) (mpp_function : string) (optimize : bool) (optimize_unsafe_float : bool) (code_coverage : bool) - (precision : string option) (test_error_margin : float option) - (m_clean_calls : bool) (dgfip_options : string list option) + (precision : string option) (roundops : string option) + (test_error_margin : float option) (m_clean_calls : bool) + (dgfip_options : string list option) (var_dependencies : (string * string) option) = + let value_sort = + let precision = Option.get precision in + if precision = "double" then Cli.RegularFloat + else + let mpfr_regex = Re.Pcre.regexp "^mpfr(\\d+)$" in + if Re.Pcre.pmatch ~rex:mpfr_regex precision then + let mpfr_prec = + Re.Pcre.get_substring (Re.Pcre.exec ~rex:mpfr_regex precision) 1 + in + Cli.MPFR (int_of_string mpfr_prec) + else if precision = "interval" then Cli.Interval + else + let bigint_regex = Re.Pcre.regexp "^fixed(\\d+)$" in + if Re.Pcre.pmatch ~rex:bigint_regex precision then + let fixpoint_prec = + Re.Pcre.get_substring (Re.Pcre.exec ~rex:bigint_regex precision) 1 + in + Cli.BigInt (int_of_string fixpoint_prec) + else if precision = "mpq" then Cli.Rational + else + Errors.raise_error + (Format.asprintf "Unkown precision option: %s" precision) + in + let round_ops = + let roundops = Option.get roundops in + if roundops = "default" then Cli.RODefault + else if roundops = "multi" then Cli.ROMulti + else + let mf_regex = Re.Pcre.regexp "^mainframe(\\d+)$" in + if Re.Pcre.pmatch ~rex:mf_regex roundops then + let mf_long_size = + Re.Pcre.get_substring (Re.Pcre.exec ~rex:mf_regex roundops) 1 + in + match int_of_string mf_long_size with + | (32 | 64) as sz -> Cli.ROMainframe sz + | _ -> + Errors.raise_error + (Format.asprintf "Invalid long size for mainframe: %s" + mf_long_size) + else + Errors.raise_error + (Format.asprintf "Unkown roundops option: %s" roundops) + in Cli.set_all_arg_refs files debug var_info_debug display_time dep_graph_file - print_cycles output optimize_unsafe_float m_clean_calls; + print_cycles output optimize_unsafe_float m_clean_calls value_sort round_ops; try let dgfip_flags = process_dgfip_options backend dgfip_options in Cli.debug_print "Reading M files..."; @@ -156,29 +200,6 @@ let driver (files : string list) (debug : bool) (var_info_debug : string list) let combined_program = Mpp_ir_to_bir.create_combined_program full_m_program mpp mpp_function in - let value_sort = - let precision = Option.get precision in - if precision = "double" then Bir_interpreter.RegularFloat - else - let mpfr_regex = Re.Pcre.regexp "^mpfr(\\d+)$" in - if Re.Pcre.pmatch ~rex:mpfr_regex precision then - let mpfr_prec = - Re.Pcre.get_substring (Re.Pcre.exec ~rex:mpfr_regex precision) 1 - in - Bir_interpreter.MPFR (int_of_string mpfr_prec) - else if precision = "interval" then Bir_interpreter.Interval - else - let bigint_regex = Re.Pcre.regexp "^fixed(\\d+)$" in - if Re.Pcre.pmatch ~rex:bigint_regex precision then - let fixpoint_prec = - Re.Pcre.get_substring (Re.Pcre.exec ~rex:bigint_regex precision) 1 - in - Bir_interpreter.BigInt (int_of_string fixpoint_prec) - else if precision = "mpq" then Bir_interpreter.Rational - else - Errors.raise_error - (Format.asprintf "Unkown precision option: %s" precision) - in if run_all_tests <> None then begin if code_coverage && optimize then Errors.raise_error @@ -188,7 +209,7 @@ let driver (files : string list) (debug : bool) (var_info_debug : string list) match run_all_tests with Some s -> s | _ -> assert false in Test_interpreter.check_all_tests combined_program tests optimize - code_coverage value_sort + code_coverage value_sort round_ops (Option.get test_error_margin) end else if run_test <> None then begin @@ -201,7 +222,7 @@ let driver (files : string list) (debug : bool) (var_info_debug : string list) in ignore (Test_interpreter.check_test combined_program test optimize false - value_sort + value_sort round_ops (Option.get test_error_margin)); Cli.result_print "Test passed!" end @@ -236,7 +257,7 @@ let driver (files : string list) (debug : bool) (var_info_debug : string list) let inputs = Bir_interface.read_inputs_from_stdin function_spec in let print_output = Bir_interpreter.evaluate_program function_spec combined_program - inputs 0 value_sort + inputs 0 value_sort round_ops in print_output () end diff --git a/src/mlang/mpp_ir/mpp_ir_to_bir.ml b/src/mlang/mpp_ir/mpp_ir_to_bir.ml index ec5123602..05b0f3a15 100644 --- a/src/mlang/mpp_ir/mpp_ir_to_bir.ml +++ b/src/mlang/mpp_ir/mpp_ir_to_bir.ml @@ -203,9 +203,8 @@ let wrap_m_code_call (m_program : Mir_interface.full_program) { m_program with program = - Bir_interpreter.RegularFloatInterpreter - .replace_undefined_with_input_variables m_program.program - ctx.variables_used_as_inputs; + Bir_interpreter.FloatDefInterp.replace_undefined_with_input_variables + m_program.program ctx.variables_used_as_inputs; } in let execution_order = @@ -558,6 +557,6 @@ let create_combined_program (m_program : Mir_interface.full_program) mir_program = m_program.program; outputs = Bir.VariableMap.empty; } - with Bir_interpreter.RegularFloatInterpreter.RuntimeError (r, ctx) -> - Bir_interpreter.RegularFloatInterpreter.raise_runtime_as_structured r ctx + with Bir_interpreter.FloatDefInterp.RuntimeError (r, ctx) -> + Bir_interpreter.FloatDefInterp.raise_runtime_as_structured r ctx m_program.program diff --git a/src/mlang/optimizing_ir/partial_evaluation.ml b/src/mlang/optimizing_ir/partial_evaluation.ml index 6a66a781d..219984076 100644 --- a/src/mlang/optimizing_ir/partial_evaluation.ml +++ b/src/mlang/optimizing_ir/partial_evaluation.ml @@ -252,10 +252,10 @@ let get_closest_dominating_def (var : Bir.variable) (ctx : partial_ev_ctx) : else Some def) let interpreter_ctx_from_partial_ev_ctx (ctx : partial_ev_ctx) : - Bir_interpreter.RegularFloatInterpreter.ctx = + Bir_interpreter.FloatDefInterp.ctx = { - Bir_interpreter.RegularFloatInterpreter.empty_ctx with - Bir_interpreter.RegularFloatInterpreter.ctx_vars = + Bir_interpreter.FloatDefInterp.empty_ctx with + Bir_interpreter.FloatDefInterp.ctx_vars = Bir.VariableMap.map Option.get (Bir.VariableMap.filter (fun _ x -> Option.is_some x) @@ -264,9 +264,8 @@ let interpreter_ctx_from_partial_ev_ctx (ctx : partial_ev_ctx) : match get_closest_dominating_def var ctx with | Some (SimpleVar (PartialLiteral l)) -> Some - (Bir_interpreter.RegularFloatInterpreter.SimpleVar - (Bir_interpreter.RegularFloatInterpreter - .literal_to_value l)) + (Bir_interpreter.FloatDefInterp.SimpleVar + (Bir_interpreter.FloatDefInterp.literal_to_value l)) | _ -> None) ctx.ctx_vars)); } @@ -293,7 +292,7 @@ let rec partially_evaluate_expr (ctx : partial_ev_ctx) (p : Mir.program) Mir.Literal (Bir_interpreter.evaluate_expr p (Pos.same_pos_as (Mir.Comparison (op, new_e1, new_e2)) e) - RegularFloat) + !Cli.value_sort !Cli.round_ops) | _ -> if d1 = Undefined || d2 = Undefined then Mir.Literal Undefined else Comparison (op, new_e1, new_e2) @@ -312,7 +311,7 @@ let rec partially_evaluate_expr (ctx : partial_ev_ctx) (p : Mir.program) from_literal (Bir_interpreter.evaluate_expr p (Pos.same_pos_as (Mir.Binop (op, new_e1, new_e2)) e1) - RegularFloat) + !Cli.value_sort !Cli.round_ops) (* first all the combinations giving undefined *) | Mast.And, (Literal Undefined, _ | _, Undefined), _ -> from_literal Undefined @@ -410,7 +409,7 @@ let rec partially_evaluate_expr (ctx : partial_ev_ctx) (p : Mir.program) from_literal (Bir_interpreter.evaluate_expr p (Pos.same_pos_as (Mir.Unop (op, new_e1)) e1) - RegularFloat) + !Cli.value_sort !Cli.round_ops) | _ -> ( ( Unop (op, new_e1), match (op, d1) with @@ -450,17 +449,16 @@ let rec partially_evaluate_expr (ctx : partial_ev_ctx) (p : Mir.program) then int_of_float f else let err, ctx = - ( Bir_interpreter.RegularFloatInterpreter.FloatIndex + ( Bir_interpreter.FloatDefInterp.FloatIndex (Format.asprintf "%f" f, Pos.get_position e1), interpreter_ctx_from_partial_ev_ctx ctx ) in if !Bir_interpreter.exit_on_rte then - Bir_interpreter.RegularFloatInterpreter - .raise_runtime_as_structured err ctx p + Bir_interpreter.FloatDefInterp.raise_runtime_as_structured + err ctx p else raise - (Bir_interpreter.RegularFloatInterpreter.RuntimeError - (err, ctx)) + (Bir_interpreter.FloatDefInterp.RuntimeError (err, ctx)) in match get_closest_dominating_def (Pos.unmark var) ctx with | Some (SimpleVar _) -> assert false (* should not happen *) @@ -615,7 +613,9 @@ let rec partially_evaluate_expr (ctx : partial_ev_ctx) (p : Mir.program) let new_e = Pos.same_pos_as (Mir.FunctionCall (func, new_args)) e in let new_e, d = if all_args_literal then - from_literal (Bir_interpreter.evaluate_expr p new_e RegularFloat) + from_literal + (Bir_interpreter.evaluate_expr p new_e !Cli.value_sort + !Cli.round_ops) else match func with | ArrFunc | InfFunc | MinFunc | MaxFunc | Multimax -> @@ -727,16 +727,14 @@ let partially_evaluate_stmt (stmt : stmt) (block_id : block_id) | Some (PartialLiteral (Float _)) -> Cli.error_print "Error during partial evaluation!"; let err, ctx = - ( Bir_interpreter.RegularFloatInterpreter.ConditionViolated + ( Bir_interpreter.FloatDefInterp.ConditionViolated (fst cond.cond_error, cond.cond_expr, []), interpreter_ctx_from_partial_ev_ctx ctx ) in if !Bir_interpreter.exit_on_rte then - Bir_interpreter.RegularFloatInterpreter.raise_runtime_as_structured - err ctx p.mir_program - else - raise - (Bir_interpreter.RegularFloatInterpreter.RuntimeError (err, ctx)) + Bir_interpreter.FloatDefInterp.raise_runtime_as_structured err ctx + p.mir_program + else raise (Bir_interpreter.FloatDefInterp.RuntimeError (err, ctx)) | _ -> ( Pos.same_pos_as (SVerif { cond with cond_expr = new_e }) stmt :: new_block, diff --git a/src/mlang/test_framework/test_interpreter.ml b/src/mlang/test_framework/test_interpreter.ml index cadc46f44..a061cda95 100644 --- a/src/mlang/test_framework/test_interpreter.ml +++ b/src/mlang/test_framework/test_interpreter.ml @@ -121,8 +121,8 @@ let to_MIR_function_and_inputs (program : Bir.program) (t : test_file) input_file ) let check_test (combined_program : Bir.program) (test_name : string) - (optimize : bool) (code_coverage : bool) - (value_sort : Bir_interpreter.value_sort) (test_error_margin : float) : + (optimize : bool) (code_coverage : bool) (value_sort : Cli.value_sort) + (round_ops : Cli.round_ops) (test_error_margin : float) : Bir_instrumentation.code_coverage_result = Cli.debug_print "Parsing %s..." test_name; let t = parse_file test_name in @@ -151,7 +151,7 @@ let check_test (combined_program : Bir.program) (test_name : string) if code_coverage then Bir_instrumentation.code_coverage_init (); let _print_outputs = Bir_interpreter.evaluate_program f combined_program input_file - (-code_loc_offset) value_sort + (-code_loc_offset) value_sort round_ops in if code_coverage then Bir_instrumentation.code_coverage_result () else Bir_instrumentation.empty_code_coverage_result @@ -173,8 +173,8 @@ let incr_int_key (m : int IntMap.t) (key : int) : int IntMap.t = | Some i -> IntMap.add key (i + 1) m let check_all_tests (p : Bir.program) (test_dir : string) (optimize : bool) - (code_coverage_activated : bool) (value_sort : Bir_interpreter.value_sort) - (test_error_margin : float) = + (code_coverage_activated : bool) (value_sort : Cli.value_sort) + (round_ops : Cli.round_ops) (test_error_margin : float) = let arr = Sys.readdir test_dir in let arr = Array.of_list @@ -220,11 +220,14 @@ let check_all_tests (p : Bir.program) (test_dir : string) (optimize : bool) (Pos.unmark err.Mir.Error.name); (successes, failures, code_coverage_acc) in + let module Interp = (val Bir_interpreter.get_interp value_sort round_ops + : Bir_interpreter.S) + in try Cli.debug_flag := false; let code_coverage_result = check_test p (test_dir ^ name) optimize code_coverage_activated - value_sort test_error_margin + value_sort round_ops test_error_margin in Cli.debug_flag := true; let code_coverage_acc = @@ -233,119 +236,27 @@ let check_all_tests (p : Bir.program) (test_dir : string) (optimize : bool) in (name :: successes, failures, code_coverage_acc) with - | Bir_interpreter.RegularFloatInterpreter.RuntimeError - ((ConditionViolated _ as cv), _) -> - let expr, err, bindings = - match cv with - | Bir_interpreter.RegularFloatInterpreter.ConditionViolated - (err, expr, bindings) -> ( - ( expr, - err, - match bindings with - | [ (v, Bir_interpreter.RegularFloatInterpreter.SimpleVar l1) ] - -> - Some - ( v, - Bir_interpreter.RegularFloatInterpreter.value_to_literal - l1 ) - | _ -> None )) - | _ -> assert false - (* should not happen *) - in - report_violated_condition_error bindings expr err - | Bir_interpreter.MPFRInterpreter.RuntimeError - ((ConditionViolated _ as cv), _) -> - let expr, err, bindings = - match cv with - | Bir_interpreter.MPFRInterpreter.ConditionViolated - (err, expr, bindings) -> ( - ( expr, - err, - match bindings with - | [ (v, Bir_interpreter.MPFRInterpreter.SimpleVar l1) ] -> - Some (v, Bir_interpreter.MPFRInterpreter.value_to_literal l1) - | _ -> None )) - | _ -> assert false - (* should not happen *) - in - report_violated_condition_error bindings expr err - | Bir_interpreter.BigIntInterpreter.RuntimeError - ((ConditionViolated _ as cv), _) -> - let expr, err, bindings = - match cv with - | Bir_interpreter.BigIntInterpreter.ConditionViolated - (err, expr, bindings) -> ( - ( expr, - err, - match bindings with - | [ (v, Bir_interpreter.BigIntInterpreter.SimpleVar l1) ] -> - Some - (v, Bir_interpreter.BigIntInterpreter.value_to_literal l1) - | _ -> None )) - | _ -> assert false - (* should not happen *) - in - report_violated_condition_error bindings expr err - | Bir_interpreter.IntervalInterpreter.RuntimeError - ((ConditionViolated _ as cv), _) -> - let expr, err, bindings = - match cv with - | Bir_interpreter.IntervalInterpreter.ConditionViolated - (err, expr, bindings) -> ( - ( expr, - err, - match bindings with - | [ (v, Bir_interpreter.IntervalInterpreter.SimpleVar l1) ] -> - Some - ( v, - Bir_interpreter.IntervalInterpreter.value_to_literal l1 - ) - | _ -> None )) - | _ -> assert false - (* should not happen *) - in - report_violated_condition_error bindings expr err - | Bir_interpreter.RationalInterpreter.RuntimeError - ((ConditionViolated _ as cv), _) -> + | Interp.RuntimeError ((ConditionViolated _ as cv), _) -> let expr, err, bindings = match cv with - | Bir_interpreter.RationalInterpreter.ConditionViolated - (err, expr, bindings) -> ( + | Interp.ConditionViolated (err, expr, bindings) -> ( ( expr, err, match bindings with - | [ (v, Bir_interpreter.RationalInterpreter.SimpleVar l1) ] -> - Some - ( v, - Bir_interpreter.RationalInterpreter.value_to_literal l1 - ) + | [ (v, Interp.SimpleVar l1) ] -> + Some (v, Interp.value_to_literal l1) | _ -> None )) | _ -> assert false (* should not happen *) in report_violated_condition_error bindings expr err - | Bir_interpreter.IntervalInterpreter.RuntimeError - (Bir_interpreter.IntervalInterpreter.StructuredError (msg, pos, kont), _) - | Bir_interpreter.BigIntInterpreter.RuntimeError - (Bir_interpreter.BigIntInterpreter.StructuredError (msg, pos, kont), _) - | Bir_interpreter.MPFRInterpreter.RuntimeError - (Bir_interpreter.MPFRInterpreter.StructuredError (msg, pos, kont), _) - | Bir_interpreter.RegularFloatInterpreter.RuntimeError - ( Bir_interpreter.RegularFloatInterpreter.StructuredError - (msg, pos, kont), - _ ) - | Bir_interpreter.RationalInterpreter.RuntimeError - (Bir_interpreter.RationalInterpreter.StructuredError (msg, pos, kont), _) + | Interp.RuntimeError (Interp.StructuredError (msg, pos, kont), _) | Errors.StructuredError (msg, pos, kont) -> Cli.error_print "Error in test %s: %a" name Errors.format_structured_error (msg, pos); (match kont with None -> () | Some kont -> kont ()); (successes, failures, code_coverage_acc) - | Bir_interpreter.IntervalInterpreter.RuntimeError (_, _) - | Bir_interpreter.BigIntInterpreter.RuntimeError (_, _) - | Bir_interpreter.MPFRInterpreter.RuntimeError (_, _) - | Bir_interpreter.RegularFloatInterpreter.RuntimeError (_, _) - | Bir_interpreter.RationalInterpreter.RuntimeError (_, _) -> + | Interp.RuntimeError (_, _) -> Cli.error_print "Runtime error in test %s" name; (successes, failures, code_coverage_acc) in diff --git a/src/mlang/test_framework/test_interpreter.mli b/src/mlang/test_framework/test_interpreter.mli index f8c99f9ff..75a85967e 100644 --- a/src/mlang/test_framework/test_interpreter.mli +++ b/src/mlang/test_framework/test_interpreter.mli @@ -18,22 +18,26 @@ val check_test : (* test file name *) string -> (* optimize *) bool -> (* code coverage *) bool -> - Bir_interpreter.value_sort -> + Cli.value_sort -> + Cli.round_ops -> (* test_error margin *) float -> Bir_instrumentation.code_coverage_result -(** [check_test test_file optimize code_coverage value_sort test_error_margin] - runs the BIR interpreter using float kind [value_sort] on a given - [test_file]. A margin or error of [test_error_margin] is tolerated between - the computed values and the expected values. [optimize] and [code_coverage] - are flags that trigger respectively compiler optimizations and code coverage - instrumentation for the interpreter run. *) +(** [check_test test_file optimize code_coverage value_sort round_ops + test_error_margin] + runs the BIR interpreter using float kind [value_sort] and rounding + operations [round_ops] on a given [test_file]. A margin or error of + [test_error_margin] is tolerated between the computed values and the + expected values. [optimize] and [code_coverage] are flags that trigger + respectively compiler optimizations and code coverage instrumentation for + the interpreter run. *) val check_all_tests : Bir.program -> string -> bool -> bool -> - Bir_interpreter.value_sort -> + Cli.value_sort -> + Cli.round_ops -> float -> unit (** Similar to [check_test] but tests a whole folder full of test files *) diff --git a/src/mlang/utils/cli.ml b/src/mlang/utils/cli.ml index 18a387564..5c08518bb 100644 --- a/src/mlang/utils/cli.ml +++ b/src/mlang/utils/cli.ml @@ -144,6 +144,20 @@ let precision = and down rounding mode), mpq (multi-precision rationals) . Default \ is double") +let roundops = + Arg.( + value + & opt (some string) (Some "default") + & info [ "roundops" ] ~docv:"ROUNDOPS" + ~doc: + "Rounding operations to use in the interpreter: default, multi, \ + mainframe (where n is the size in bits of the long type to \ + simulate). Each corresponds to the behavior of the legacy DGFiP \ + code in different environments: default when running on a regular \ + PC, multi when running in a multithread context, and mainframe when \ + running on a mainframe. In this case, the size of the long type has \ + to be specified; it can be either 32 or 64.") + let test_error_margin = Arg.( value @@ -183,7 +197,7 @@ let mlang_t f = const f $ files $ debug $ var_info_debug $ display_time $ dep_graph_file $ no_print_cycles $ backend $ function_spec $ mpp_file $ output $ run_all_tests $ run_test $ mpp_function $ optimize $ optimize_unsafe_float - $ code_coverage $ precision $ test_error_margin $ m_clean_calls + $ code_coverage $ precision $ roundops $ test_error_margin $ m_clean_calls $ dgfip_options $ var_dependencies) let info = @@ -228,6 +242,16 @@ let info = | Some v -> Build_info.V1.Version.to_string v) ~doc ~exits ~man +type value_sort = + | RegularFloat + | MPFR of int (* bitsize of the floats *) + | BigInt of int (* precision of the fixed point *) + | Interval + | Rational + +type round_ops = RODefault | ROMulti | ROMainframe of int +(* size of type long, either 32 or 64 *) + let source_files : string list ref = ref [] let dep_graph_file : string ref = ref "dep_graph.dot" @@ -252,11 +276,16 @@ let optimize_unsafe_float = ref false let m_clean_calls = ref false +let value_sort = ref RegularFloat + +let round_ops = ref RODefault + let set_all_arg_refs (files_ : string list) (debug_ : bool) (var_info_debug_ : string list) (display_time_ : bool) (dep_graph_file_ : string) (no_print_cycles_ : bool) (output_file_ : string option) (optimize_unsafe_float_ : bool) - (m_clean_calls_ : bool) = + (m_clean_calls_ : bool) (value_sort_ : value_sort) (round_ops_ : round_ops) + = source_files := files_; debug_flag := debug_; var_info_debug := var_info_debug_; @@ -266,6 +295,8 @@ let set_all_arg_refs (files_ : string list) (debug_ : bool) no_print_cycles_flag := no_print_cycles_; optimize_unsafe_float := optimize_unsafe_float_; m_clean_calls := m_clean_calls_; + value_sort := value_sort_; + round_ops := round_ops_; match output_file_ with None -> () | Some o -> output_file := o (**{1 Terminal formatting}*) diff --git a/src/mlang/utils/cli.mli b/src/mlang/utils/cli.mli index 534c0215d..e6376f9b3 100644 --- a/src/mlang/utils/cli.mli +++ b/src/mlang/utils/cli.mli @@ -36,6 +36,7 @@ val mlang_t : bool -> bool -> string option -> + string option -> float option -> bool -> string list option -> @@ -49,6 +50,27 @@ val info : Cmdliner.Cmd.info (**{2 Flags and parameters}*) +(** According on the [value_sort], a specific interpreter will be called with + the right kind of floating-point value *) +type value_sort = + | RegularFloat + | MPFR of int (** bitsize of the floats *) + | BigInt of int (** precision of the fixed point *) + | Interval + | Rational + +(** Rounding operations to use in the interpreter. They correspond to the + rounding operations used by the DGFiP calculator in different execution + contexts. + + - RODefault: rounding operations used in the PC/single-thread context + - ROMulti: rouding operations used in the PC/multi-thread context + - ROMainframe rounding operations used in the mainframe context *) +type round_ops = + | RODefault + | ROMulti + | ROMainframe of int (** size of type long, either 32 or 64 *) + val source_files : string list ref (** M source files to be compiled *) @@ -85,6 +107,10 @@ val optimize_unsafe_float : bool ref val m_clean_calls : bool ref (** Clean regular variables between M calls *) +val value_sort : value_sort ref + +val round_ops : round_ops ref + val set_all_arg_refs : (* files *) string list -> (* debug *) bool -> @@ -95,6 +121,8 @@ val set_all_arg_refs : (* output_file *) string option -> (* optimize_unsafe_float *) bool -> (* m_clean_call *) bool -> + value_sort -> + round_ops -> unit val add_prefix_to_each_line : string -> (int -> string) -> string