From 5a04c74c0a14b704664c7f2cceb20251dd2b5371 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Fri, 24 Jan 2025 22:16:43 +0100 Subject: [PATCH] Fixes #303: major expansion of available operations, work in progress --- arrayjit/lib/low_level.ml | 65 +++++++++++- arrayjit/lib/low_level.mli | 10 ++ arrayjit/lib/ops.ml | 198 ++++++++++++++++++++++++++++++++++--- lib/attic.mld | 41 ++++++++ todo-archive.md | 8 +- todo.md | 7 -- 6 files changed, 308 insertions(+), 21 deletions(-) diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml index 559f26fb..ae854629 100644 --- a/arrayjit/lib/low_level.ml +++ b/arrayjit/lib/low_level.ml @@ -847,7 +847,7 @@ let get_ident_within_code ?no_dots llcs = Tn.update_code_name tn ident; ident -let fprint_hum ?name ?static_indices () ppf llc = +let fprint_cstyle ?name ?static_indices () ppf llc = let ident_label = get_ident_within_code [| llc |] in let open Stdlib.Format in pp_set_margin ppf !code_hum_margin; @@ -899,7 +899,68 @@ let fprint_hum ?name ?static_indices () ppf llc = let prefix, infix, postfix = Ops.binop_c_syntax prec op in fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix (pp_float prec) v1 infix (pp_float prec) v2 postfix | Unop (Identity, v) -> (pp_float prec) ppf v - | Unop (Relu, v) -> fprintf ppf "@[<1>relu(%a@])" (pp_float prec) v + | Unop (op, v) -> + let prefix, postfix = Ops.unop_c_syntax prec op in + fprintf ppf "%s%a%s" prefix (pp_float prec) v postfix + in + fprintf ppf "@,@["; + fprint_function_header ?name ?static_indices () ppf; + pp_ll ppf llc; + fprintf ppf "@]" + +let fprint_hum ?name ?static_indices () ppf llc = + let ident_label = get_ident_within_code [| llc |] in + let open Stdlib.Format in + pp_set_margin ppf !code_hum_margin; + let pp_ident ppf la = fprintf ppf "%s" @@ ident_label la in + let pp_local ppf { tn; scope_id } = fprintf ppf "v%d_%a" scope_id pp_ident tn in + let rec pp_ll ppf c : unit = + match c with + | Noop -> () + | Seq (c1, c2) -> + fprintf ppf "@[%a@]" (pp_print_list pp_ll) + (List.filter [ c1; c2 ] ~f:(function Noop -> false | _ -> true)) + | For_loop { index = i; from_; to_; body; trace_it = _ } -> + fprintf ppf "@[for %a = %d to %d {@ %a@]@,}" pp_symbol i from_ to_ pp_ll body + | Zero_out tn -> fprintf ppf "zero_out %a;" pp_ident tn + | Set p -> + p.debug <- asprintf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs pp_float p.llv; + fprintf ppf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs pp_float p.llv + | Comment message -> fprintf ppf "/* %s */" message + | Staged_compilation _ -> fprintf ppf "STAGED_COMPILATION_CALLBACK()" + | Set_local (id, llv) -> fprintf ppf "@[<2>%a :=@ %a;@]" pp_local id pp_float llv + and pp_float ppf value = + match value with + | Local_scope { id; body; _ } -> fprintf ppf "@[<2>%a {@ %a@]@ }@," pp_local id pp_ll body + | Get_local id -> pp_local ppf id + | Get_global (Ops.C_function s, None) -> fprintf ppf "%s()" s + | Get_global (Ops.C_function s, Some idcs) -> fprintf ppf "%s(%a)" s pp_indices idcs + | Get_global (Ops.External_unsafe { ptr; prec; dims = _ }, None) -> + fprintf ppf "%s" @@ Ops.ptr_to_string_hum ptr prec + | Get_global (Ops.External_unsafe { ptr; prec; dims = _ }, Some idcs) -> + fprintf ppf "%s[%a]" (Ops.ptr_to_string_hum ptr prec) pp_indices idcs + | Get_global (Ops.Merge_buffer { source_node_id }, None) -> + let tn = Option.value_exn ~here:[%here] @@ Tnode.find ~id:source_node_id in + fprintf ppf "%a.merge" pp_ident tn + | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) -> + let tn = Option.value_exn ~here:[%here] @@ Tnode.find ~id:source_node_id in + fprintf ppf "@[<2>%a.merge[@,%a]@]" pp_ident tn pp_indices idcs + | Get (tn, idcs) -> fprintf ppf "@[<2>%a[@,%a]@]" pp_ident tn pp_indices idcs + | Constant c -> fprintf ppf "%.16g" c + | Embed_index idx -> pp_axis_index ppf idx + | Binop (Arg1, v1, _v2) -> pp_float ppf v1 + | Binop (Arg2, _v1, v2) -> pp_float ppf v2 + | Binop (op, v1, v2) -> + if Ops.is_binop_nice_infix op then + let infix = Ops.binop_cd_syntax op in + fprintf ppf "@[<1>(%a %s@ %a@])" pp_float v1 infix pp_float v2 + else + let prefix = Ops.binop_cd_fallback_syntax op in + fprintf ppf "@[<1>%s(%a,@ %a@])" prefix pp_float v1 pp_float v2 + | Unop (Identity, v) -> pp_float ppf v + | Unop (op, v) -> + let prefix = Ops.unop_cd_syntax op in + fprintf ppf "%s(%a)" prefix pp_float v in fprintf ppf "@,@["; fprint_function_header ?name ?static_indices () ppf; diff --git a/arrayjit/lib/low_level.mli b/arrayjit/lib/low_level.mli index c75fc6d4..1b2cc4b4 100644 --- a/arrayjit/lib/low_level.mli +++ b/arrayjit/lib/low_level.mli @@ -118,6 +118,15 @@ val fprint_function_header : val get_ident_within_code : ?no_dots:bool -> t array -> Tnode.t -> string +val fprint_cstyle : + ?name:string -> + ?static_indices:Indexing.static_symbol list -> + unit -> + Stdlib.Format.formatter -> + t -> + unit +(** Adheres more to the C syntax, outputs implicit type casts. *) + val fprint_hum : ?name:string -> ?static_indices:Indexing.static_symbol list -> @@ -125,3 +134,4 @@ val fprint_hum : Stdlib.Format.formatter -> t -> unit +(** Adheres more to the %cd syntax, does not output implicit type casts. *) diff --git a/arrayjit/lib/ops.ml b/arrayjit/lib/ops.ml index dc224567..342f87f8 100644 --- a/arrayjit/lib/ops.ml +++ b/arrayjit/lib/ops.ml @@ -112,7 +112,14 @@ let hum_typ_of_prec = function | Double_prec _ -> "double" | Void_prec -> "void" -(** {2 *** Operations ***} *) +(** {2 *** Operations ***} + + See: {{https://github.com/tinygrad/tinygrad/blob/master/tinygrad/ops.py#L123} tinygrad ops}, + {{https://docs.nvidia.com/cuda/cuda-math-api/index.html} CUDA Math API} (intrinsics). + + This is a redundant set of operations, aiming to expose hardware-supported "intrinsics", + to reduce the need for backends to pattern-match and optimize. Also for convenience. +*) (** Initializes or resets a array by filling in the corresponding numbers, at the appropriate precision. *) @@ -127,10 +134,49 @@ type init_op = | File_mapped of string * prec (** Reads the data using [Unix.openfile] and [Unix.map_file]. *) [@@deriving equal, sexp] -type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1 +type binop = + | Add + | Sub + | Mul + | Div + | ToPowOf + | Relu_gate + | Arg2 + | Arg1 + | Max + | Min + | Mod + | Cmplt + | Cmpne + (* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *) + (* | Shl *) + (* | Shr *) + | Or + | And + | Threefry (** Counter-based random number generator. *) [@@deriving sexp, compare, equal] -type unop = Identity | Relu [@@deriving sexp, compare, equal] +type unop = + | Identity + | Relu + | Satur01 (** Saturate (truncate) to within the interval [[0; 1]]. *) + | Exp + | Log + | Exp2 + | Log2 + | Exp10 + | Log10 + | Sin + | Cos + | Sqrt + | Recip + | Recip_sqrt + | Neg + | Tanh_approx +[@@deriving sexp, compare, equal] + +type ternop = Where (** Where(a,b,c): if a then b else c *) | FMA (** FMA(a,b,c): (a * b) + c *) +[@@deriving sexp, compare, equal] (** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation does not have a neutral element. *) @@ -139,8 +185,11 @@ let neutral_elem = function | Mul | Div -> 1. | ToPowOf -> 1. | Relu_gate -> 1. - | Arg2 -> 0. - | Arg1 -> 0. + | Max -> Float.neg_infinity + | Min -> Float.infinity + | And -> 1. + | Or -> 0. + | Arg2 | Arg1 | Mod | Cmplt | Cmpne (* | Shl | Shr *) | Threefry -> 0. let interpret_binop op v1 v2 = let open Float in @@ -153,10 +202,47 @@ let interpret_binop op v1 v2 = | Div -> v1 / v2 | ToPowOf -> if is_integer v2 then int_pow v1 @@ to_int v2 else v1 ** v2 | Relu_gate -> if v1 > 0.0 then v2 else 0.0 + | Max -> max v1 v2 + | Min -> min v1 v2 + | Mod -> v1 % v2 + | Cmplt -> if v1 < v2 then 1. else 0. + | Cmpne -> if v1 <> v2 then 1. else 0. + (* | Shl -> v1 * (int_pow 2. @@ to_int v2) *) + (* | Shr -> v1 / (int_pow 2. @@ to_int v2) *) + | Or -> if v1 <> 0. || v2 <> 0. then 1. else 0. + | And -> if v1 <> 0. && v2 <> 0. then 1. else 0. + | Threefry -> + (* FIXME: NOT IMPLEMENTED YET *) + failwith "FIXME: NOT IMPLEMENTED YET" let interpret_unop op v = let open Float in - match op with Identity -> v | Relu when v >= 0. -> v | Relu -> 0. + match op with + | Identity -> v + | Relu when v >= 0. -> v + | Relu -> 0. + | Satur01 when v <= 0. -> 0. + | Satur01 when v >= 1. -> 1. + | Satur01 -> v + | Exp -> exp v + | Log -> log v + | Exp2 -> 2. ** v + | Log2 -> log v / log 2. + | Exp10 -> 10. ** v + | Log10 -> log v / log 10. + | Sin -> sin v + | Cos -> cos v + | Sqrt -> sqrt v + | Recip -> 1. / v + | Recip_sqrt -> 1. / sqrt v + | Neg -> ~-.v + | Tanh_approx -> tanh v + +let is_binop_infix = function Threefry -> false | _ -> true + +let is_binop_nice_infix = function + | Arg1 | Arg2 | Relu_gate | Max | Min | Threefry -> false + | _ -> true let binop_cd_syntax = function | Arg1 -> "-@>" @@ -167,6 +253,36 @@ let binop_cd_syntax = function | Div -> "/" | ToPowOf -> "**" | Relu_gate -> "-?/" + | Cmplt -> "<" + | Cmpne -> "<>" + | Or -> "||" + | And -> "&&" + | Mod -> "%" + | Max -> "@^" + | Min -> "^^" + (* | Shl -> "lsl" *) + (* | Shr -> "lsr" *) + | Threefry -> "threefry" + +let binop_cd_fallback_syntax = function + | Arg1 -> "fst" + | Arg2 -> "snd" + | Add -> "add" + | Sub -> "sub" + | Mul -> "mul" + | Div -> "div" + | ToPowOf -> "pow" + | Relu_gate -> "relu_gate" + | Cmplt -> "lt" + | Cmpne -> "le" + | Or -> "orf" + | And -> "andf" + | Mod -> "modf" + | Max -> "max" + | Min -> "min" + (* | Shl -> "shlf" *) + (* | Shr -> "shrf" *) + | Threefry -> "threefry" let binop_c_syntax prec v = match (v, prec) with @@ -184,9 +300,33 @@ let binop_c_syntax prec v = invalid_arg "Ops.binop_c_syntax: ToPowOf not supported for byte/integer precisions" | Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)") | Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)") + | Max, Double_prec _ -> ("fmax(", ",", ")") + | Max, Single_prec _ -> ("fmaxf(", ",", ")") + | Max, Half_prec _ -> ("fmaxf(", ",", ")") + | Max, Byte_prec _ -> ("fmax(", ",", ")") + | Min, Double_prec _ -> ("fmin(", ",", ")") + | Min, Single_prec _ -> ("fminf(", ",", ")") + | Min, Half_prec _ -> ("fminf(", ",", ")") + | Min, Byte_prec _ -> ("fmin(", ",", ")") + | Mod, _ -> ("(", " %", ")") + | Cmplt, _ -> ("(", " <", ")") + | Cmpne, _ -> ("(", " !=", ")") + (* | Shl, Byte_prec _ -> ("(", " <<", ")") *) + (* | Shl, _ -> ("((", ") * exp2(", "))") *) + (* | Shr, Byte_prec _ -> ("(", " >>", ")") *) + (* | Shr, _ -> ("((", ") / exp2(", "))") *) + | Or, _ -> ("(", " ||", ")") + | And, _ -> ("(", " &&", ")") + | Threefry, Double_prec _ -> ("threefry(", ",", ")") + | Threefry, Single_prec _ -> ("threefryf(", ",", ")") + | Threefry, Half_prec _ -> ("threefryf(", ",", ")") + | Threefry, Byte_prec _ -> ("threefryf(", ",", ")") + +let is_assign_op = function + | Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne | Threefry -> false + | Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true let assign_op_cd_syntax ~initialize_neutral = function - | Arg1 -> invalid_arg "Ops.assign_op_cd_syntax: Arg1 is not a %cd assignment operator" | Arg2 -> "=:" | Add when initialize_neutral -> "=:+" | Sub when initialize_neutral -> "=:-" @@ -194,12 +334,22 @@ let assign_op_cd_syntax ~initialize_neutral = function | Div when initialize_neutral -> "=:/" | ToPowOf when initialize_neutral -> "=:**" | Relu_gate when initialize_neutral -> "=:?/" + | Or when initialize_neutral -> "=:||" + | And when initialize_neutral -> "=:&&" + | Max when initialize_neutral -> "=:@^" + | Min when initialize_neutral -> "=:^^" | Add -> "=+" | Sub -> "=-" | Mul -> "=*" | Div -> "=/" | ToPowOf -> "=**" | Relu_gate -> "=?/" + | Max -> "=@^" + | Min -> "=^^" + | Or -> "=||" + | And -> "=&&" + | Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne | Threefry -> + invalid_arg "Ops.assign_op_cd_syntax: not an assignment op" let assign_op_c_syntax = function | Arg1 -> invalid_arg "Ops.assign_op_c_syntax: Arg1 is not a C assignment operator" @@ -208,10 +358,29 @@ let assign_op_c_syntax = function | Sub -> "-=" | Mul -> "*=" | Div -> "/=" - | ToPowOf -> invalid_arg "Ops.assign_op_c_syntax: ToPowOf function is not a C assignment operator" - | Relu_gate -> invalid_arg "Ops.assign_op_c_syntax: Relu_gate is not a C assignment operator" - -let unop_cd_syntax = function Identity -> "~=" | Relu -> "?/" + | Mod -> "%=" + (* | Shl -> "<<=" *) + (* | Shr -> ">>=" *) + | _ -> invalid_arg "Ops.assign_op_c_syntax: not a C assignment operator" + +(** Note: currently we do not support unary prefix symbols. *) +let unop_cd_syntax = function + | Identity -> "id" + | Relu -> "relu" + | Satur01 -> "sat01" + | Exp -> "exp" + | Log -> "log" + | Exp2 -> "exp2" + | Log2 -> "log2" + | Exp10 -> "exp10" + | Log10 -> "log10" + | Sin -> "sin" + | Cos -> "cos" + | Sqrt -> "sqrt" + | Recip -> "recip" + | Recip_sqrt -> "recip_sqrt" + | Neg -> "neg" + | Tanh_approx -> "tanh" let unop_c_syntax prec v = match (v, prec) with @@ -219,6 +388,13 @@ let unop_c_syntax prec v = | Relu, Single_prec _ -> ("fmaxf(0.0, ", ")") | Relu, Byte_prec _ -> ("fmax(0, ", ")") | Relu, _ -> ("fmax(0.0, ", ")") + | _ -> + (* FIXME: NOT IMPLEMENTED YET *) + failwith "NOT IMPLEMENTED YET" +(* | Satur01, _ -> ("", "") | Exp, _ -> ("", "") | Log, _ -> ("", "") | Exp2, _ -> ("", "") | Log2, + _ -> ("", "") | Exp10, _ -> ("", "") | Log10, _ -> ("", "") | Sin, _ -> ("", "") | Cos, _ -> ("", + "") | Sqrt, _ -> ("", "") | Recip, _ -> ("", "") | Recip_sqrt, _ -> ("", "") | Neg, _ -> ("", "") + | Tanh_approx, _ -> ("", "") *) let c_convert_precision ~from ~to_ = match (from, to_) with diff --git a/lib/attic.mld b/lib/attic.mld index 8fff54cc..94ddffe1 100644 --- a/lib/attic.mld +++ b/lib/attic.mld @@ -499,4 +499,45 @@ let%track3_sexp finalize (ctx : context) : unit = (not (Option.exists ctx.parent ~f:(fun pc -> Map.mem pc.ctx_arrays key))) && not (Hashtbl.mem ctx.stream.device.cross_stream_candidates key) then Cu.Deviceptr.mem_free data)) +]} + +Adding constants to the representation is probably a bad idea... File Ops.ml: +{[ + +type constant = Min_noninf | Zero | One | Pi | Max_noninf | C of float +[@@deriving sexp, compare, equal] + +let float_of_c prec c = + match (c, prec) with + | _, Void_prec -> invalid_arg "Ops.float_of_c: void precision" + | Min_noninf, Double_prec _ -> ~-.Float.max_finite_value + | Min_noninf, Single_prec _ -> ~-.Float.((2. ** 127.) *. (2. -. (2. ** -23.))) + | Min_noninf, Half_prec _ -> ~-.Float.((2. ** 15.) *. (2. - (2. ** -10.))) + | Min_noninf, Byte_prec _ -> -127. + | Zero, _ -> 0. + | One, _ -> 1. + | Pi, _ -> Float.pi + | Max_noninf, Double_prec _ -> Float.max_finite_value + | Max_noninf, Single_prec _ -> Float.((2. ** 127.) *. (2. -. (2. ** -23.))) + | Max_noninf, Half_prec _ -> Float.((2. ** 15.) *. (2. - (2. ** -10.))) + | Max_noninf, Byte_prec _ -> 128. + | C c, _ -> c + +let constant_cd_syntax = function + | Min_noninf -> "min_val" + | Zero -> "0" + | One -> "1" + | Pi -> "pi" + | Max_noninf -> "max_val" + | C c -> Printf.sprintf "%g" c + +let constant_c_syntax = function + | Min_noninf -> "(-DBL_MAX)" + | Zero -> "0.0" + | One -> "1.0" + | Pi -> "M_PI" + | Max_noninf -> "DBL_MAX" + | C c when Float.(c < 0.) -> Printf.sprintf "(%g)" c + | C c -> Printf.sprintf "%g" c + ]} \ No newline at end of file diff --git a/todo-archive.md b/todo-archive.md index 05d74df8..2c3a5b38 100644 --- a/todo-archive.md +++ b/todo-archive.md @@ -3,4 +3,10 @@ (B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23} (A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22} (B) figure out why cuda backend parallelism slows down in later epochs {cm:2024-11-25} -clean up event hashtables when a stream or device gets synchronized {cm:2024-12-03} \ No newline at end of file +clean up event hashtables when a stream or device gets synchronized {cm:2024-12-03} +(A) Ensure that reading from host on CPU performs required synchronization {cm:2024-12-31} +Update `anatomy_of_a_backend.md` {cm:2025-01-01} +Update introductory slides {cm:2024-12-17} +Config to skip capturing logs from stdout {cm:2024-12-18} +Automatic blocking on access of a host array when a scheduled `to_host` transfer has not finished {cm:2025-01-01} +Migrate graphing to PrintBox-distributed extension {cm:2025-01-24} \ No newline at end of file diff --git a/todo.md b/todo.md index d07a941d..c54059cb 100644 --- a/todo.md +++ b/todo.md @@ -1,8 +1 @@ # This file is for tasks with a smaller granularity than issues, typically immediate tasks. -(A) Ensure that reading from host on CPU performs required synchronization {cm:2024-12-31} - -Update `anatomy_of_a_backend.md` {cm:2025-01-01} -Update introductory slides {cm:2024-12-17} -Config to skip capturing logs from stdout {cm:2024-12-18} -Automatic blocking on access of a host array when a scheduled `to_host` transfer has not finished {cm:2025-01-01} -Migrate graphing to PrintBox-distributed extension \ No newline at end of file