diff --git a/Cargo.lock b/Cargo.lock index fee9a01e..48dfd05e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1614,6 +1614,7 @@ dependencies = [ "serde", "serde_json", "service_utils", + "superposition_derives", "superposition_macros", "superposition_types", "uuid", @@ -3845,6 +3846,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "superposition_derives" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "superposition_macros" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index d56fe538..30d8391a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "examples/cac_client_integration_example", "examples/superposition-demo-app", "crates/superposition_macros", + "crates/superposition_derives", ] [[workspace.metadata.leptos]] diff --git a/cog.toml b/cog.toml index 09b7c1e3..b03332ce 100644 --- a/cog.toml +++ b/cog.toml @@ -9,7 +9,7 @@ pre_bump_hooks = [] pre_package_bump_hooks = [ "echo 'upgrading {{package}}' to {{version}}", - "cargo set-version --package {{package}} {{version}}" + "cargo set-version --package {{package}} {{version}}", ] post_package_bump_hooks = [] @@ -36,3 +36,4 @@ superposition_types = { path = "crates/superposition_types" } js_client = { path = "clients/js" } haskell_client = { path = "clients/haskell" } superposition_macros = { path = "crates/superposition_macros" } +superposition_derives = { path = "crates/superposition_derives" } diff --git a/crates/experimentation_platform/Cargo.toml b/crates/experimentation_platform/Cargo.toml index 1f3e5099..7fdb0c49 100644 --- a/crates/experimentation_platform/Cargo.toml +++ b/crates/experimentation_platform/Cargo.toml @@ -18,6 +18,7 @@ reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } service_utils = { path = "../service_utils" } +superposition_derives = { path = "../superposition_derives" } superposition_macros = { path = "../superposition_macros" } superposition_types = { path = "../superposition_types", features = ["result"] } uuid = { workspace = true } diff --git a/crates/experimentation_platform/src/api/experiments/handlers.rs b/crates/experimentation_platform/src/api/experiments/handlers.rs index 1f27fe67..f07830c9 100644 --- a/crates/experimentation_platform/src/api/experiments/handlers.rs +++ b/crates/experimentation_platform/src/api/experiments/handlers.rs @@ -15,10 +15,9 @@ use diesel::{ r2d2::{ConnectionManager, PooledConnection}, ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl, }; - -use service_utils::helpers::{construct_request_headers, generate_snowflake_id, request}; - use reqwest::{Method, Response, StatusCode}; +use serde_json::{json, Map, Value}; +use service_utils::helpers::{construct_request_headers, generate_snowflake_id, request}; use service_utils::service::types::{ AppHeader, AppState, CustomHeaders, DbConnection, Tenant, }; @@ -32,23 +31,18 @@ use super::{ validate_experiment, validate_override_keys, }, types::{ - AuditQueryFilters, ConcludeExperimentRequest, ContextAction, ContextBulkResponse, - ContextMoveReq, ContextPutReq, ExperimentCreateRequest, ExperimentCreateResponse, - ExperimentResponse, ExperimentsResponse, ListFilters, OverrideKeysUpdateRequest, - RampRequest, Variant, + ApplicableVariantsQuery, AuditQueryFilters, ConcludeExperimentRequest, + ContextAction, ContextBulkResponse, ContextMoveReq, ContextPutReq, + ExperimentCreateRequest, ExperimentCreateResponse, ExperimentResponse, + ExperimentsResponse, ListFilters, OverrideKeysUpdateRequest, RampRequest, }, }; -use crate::{ - api::experiments::types::ApplicableVariantsQuery, - db::{ - models::{EventLog, Experiment, ExperimentStatusType}, - schema::{event_log::dsl as event_log, experiments::dsl as experiments}, - }, +use crate::db::{ + models::{EventLog, Experiment, ExperimentStatusType, Variant, Variants}, + schema::{event_log::dsl as event_log, experiments::dsl as experiments}, }; -use serde_json::{json, Map, Value}; - pub fn endpoints(scope: Scope) -> Scope { scope .service(get_audit_logs) @@ -301,7 +295,7 @@ async fn create( traffic_percentage: 0, status: ExperimentStatusType::CREATED, context: Value::Object(req.context.clone().into_inner().into()), - variants: serde_json::to_value(variants).unwrap(), + variants: Variants::new(variants), last_modified_by: user.get_email(), chosen_variant: None, }; @@ -374,17 +368,9 @@ pub async fn conclude( })?; let mut operations: Vec = vec![]; - let experiment_variants: Vec = serde_json::from_value(experiment.variants) - .map_err(|err| { - log::error!( - "failed parse eixisting experiment variant while concluding with error: {}", - err - ); - unexpected_error!("Something went wrong, failed to conclude experiment") - })?; let mut is_valid_winner_variant = false; - for variant in experiment_variants { + for variant in experiment.variants.into_inner() { let context_id = variant.context_id.ok_or_else(|| { log::error!("context id not available for variant {:?}", variant.id); unexpected_error!("Something went wrong, failed to conclude experiment") @@ -517,10 +503,7 @@ async fn get_applicable_variants( for exp in experiments { if let Some(v) = decide_variant( exp.traffic_percentage as u8, - serde_json::from_value(exp.variants).map_err(|e| { - log::error!("Unable to parse variants from DB {e}"); - unexpected_error!("Something went wrong.") - })?, + exp.variants.into_inner(), query_data.toss, ) .map_err(|e| { @@ -639,15 +622,7 @@ async fn ramp( let old_traffic_percentage = experiment.traffic_percentage as u8; let new_traffic_percentage = req.traffic_percentage as u8; - let experiment_variants: Vec = serde_json::from_value(experiment.variants) - .map_err(|e| { - log::error!( - "failed to parse existing experiment variants while ramping {}", - e - ); - unexpected_error!("Something went wrong, failed to ramp traffic percentage") - })?; - let variants_count = experiment_variants.len() as u8; + let variants_count = experiment.variants.into_inner().len() as u8; let max = 100 / variants_count; if matches!(experiment.status, ExperimentStatusType::CONCLUDED) { @@ -711,11 +686,7 @@ async fn update_overrides( )); } - let experiment_variants: Vec = serde_json::from_value(experiment.variants) - .map_err(|err| { - log::error!("failed to parse exisiting variants with error {}", err); - unexpected_error!("Something went wrong, failed to update experiment") - })?; + let experiment_variants: Vec = experiment.variants.into_inner(); let id_to_existing_variant: HashMap = HashMap::from_iter( experiment_variants diff --git a/crates/experimentation_platform/src/api/experiments/helpers.rs b/crates/experimentation_platform/src/api/experiments/helpers.rs index 9658f89d..cacc76a9 100644 --- a/crates/experimentation_platform/src/api/experiments/helpers.rs +++ b/crates/experimentation_platform/src/api/experiments/helpers.rs @@ -1,5 +1,3 @@ -use super::types::{Variant, VariantType}; -use crate::db::models::{Experiment, ExperimentStatusType}; use diesel::pg::PgConnection; use diesel::{BoolExpressionMethods, ExpressionMethods, QueryDsl, RunQueryDsl}; use serde_json::{Map, Value}; @@ -9,6 +7,8 @@ use std::collections::HashSet; use superposition_macros::{bad_argument, unexpected_error}; use superposition_types::{result as superposition, Condition, Exp, Overrides}; +use crate::db::models::{Experiment, ExperimentStatusType, Variant, VariantType}; + pub fn check_variant_types(variants: &Vec) -> superposition::Result<()> { let mut experimental_variant_cnt = 0; let mut control_variant_cnt = 0; diff --git a/crates/experimentation_platform/src/api/experiments/types.rs b/crates/experimentation_platform/src/api/experiments/types.rs index c968c03c..5737c5b4 100644 --- a/crates/experimentation_platform/src/api/experiments/types.rs +++ b/crates/experimentation_platform/src/api/experiments/types.rs @@ -6,22 +6,7 @@ use serde_json::{Map, Value}; use service_utils::helpers::deserialize_stringified_list; use superposition_types::{Condition, Exp, Overrides}; -use crate::db::models::{self, ExperimentStatusType}; - -#[derive(Deserialize, Serialize, Clone, PartialEq, Debug)] -pub enum VariantType { - CONTROL, - EXPERIMENTAL, -} - -#[derive(Deserialize, Serialize, Clone)] -pub struct Variant { - pub id: String, - pub variant_type: VariantType, - pub context_id: Option, - pub override_id: Option, - pub overrides: Exp, -} +use crate::db::models::{self, ExperimentStatusType, Variant}; /********** Experiment Create Req Types ************/ @@ -62,7 +47,7 @@ pub struct ExperimentResponse { pub traffic_percentage: i32, pub context: Value, - pub variants: Value, + pub variants: Vec, pub last_modified_by: String, pub chosen_variant: Option, } @@ -81,7 +66,7 @@ impl From for ExperimentResponse { traffic_percentage: experiment.traffic_percentage, context: experiment.context, - variants: experiment.variants, + variants: experiment.variants.into_inner(), last_modified_by: experiment.last_modified_by, chosen_variant: experiment.chosen_variant, } diff --git a/crates/experimentation_platform/src/db/models.rs b/crates/experimentation_platform/src/db/models.rs index 830f9042..0f2a9fa1 100644 --- a/crates/experimentation_platform/src/db/models.rs +++ b/crates/experimentation_platform/src/db/models.rs @@ -2,10 +2,13 @@ use crate::db::schema::*; use chrono::{DateTime, NaiveDateTime, Utc}; use diesel::{ - query_builder::QueryId, Insertable, Queryable, QueryableByName, Selectable, + deserialize::FromSqlRow, expression::AsExpression, query_builder::QueryId, + sql_types::Json, Insertable, Queryable, QueryableByName, Selectable, }; use serde::{Deserialize, Serialize}; use serde_json::Value; +use superposition_derives::{JsonFromSql, JsonToSql}; +use superposition_types::{Exp, Overrides}; #[derive( Debug, @@ -25,6 +28,39 @@ pub enum ExperimentStatusType { INPROGRESS, } +#[derive(Deserialize, Serialize, Clone, PartialEq, Debug)] +pub enum VariantType { + CONTROL, + EXPERIMENTAL, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Variant { + pub id: String, + pub variant_type: VariantType, + #[serde(skip_serializing_if = "Option::is_none")] + pub context_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub override_id: Option, + pub overrides: Exp, +} + +#[derive( + Debug, Clone, Serialize, Deserialize, AsExpression, FromSqlRow, JsonFromSql, JsonToSql, +)] +#[diesel(sql_type = Json)] +pub struct Variants(Vec); + +impl Variants { + pub fn new(data: Vec) -> Self { + Self(data) + } + + pub fn into_inner(self) -> Vec { + self.0 + } +} + #[derive(QueryableByName, Queryable, Selectable, Insertable, Serialize, Clone, Debug)] #[diesel(check_for_backend(diesel::pg::Pg))] #[diesel(primary_key(id))] @@ -40,7 +76,7 @@ pub struct Experiment { pub traffic_percentage: i32, pub context: Value, - pub variants: Value, + pub variants: Variants, pub last_modified_by: String, pub chosen_variant: Option, } diff --git a/crates/experimentation_platform/tests/experimentation_tests.rs b/crates/experimentation_platform/tests/experimentation_tests.rs index 4b317dd2..738f6b3f 100644 --- a/crates/experimentation_platform/tests/experimentation_tests.rs +++ b/crates/experimentation_platform/tests/experimentation_tests.rs @@ -1,6 +1,8 @@ use chrono::Utc; use experimentation_platform::api::experiments::helpers; -use experimentation_platform::db::models::{Experiment, ExperimentStatusType}; +use experimentation_platform::db::models::{ + Experiment, ExperimentStatusType, Variant, Variants, +}; use serde_json::{json, Map, Value}; use service_utils::helpers::extract_dimensions; use service_utils::service::types::ExperimentationFlags; @@ -52,7 +54,7 @@ fn experiment_gen( override_keys: &[String], context: &Map, status: ExperimentStatusType, - variants: &Value, + variants: &Vec, ) -> Experiment { Experiment { id: 123456789, @@ -66,7 +68,7 @@ fn experiment_gen( override_keys: override_keys.to_vec(), status, context: json!(context.clone()), - variants: variants.clone(), + variants: Variants::new(variants.clone()), chosen_variant: None, } } @@ -231,7 +233,7 @@ fn test_is_valid_experiment_no_restrictions_overlapping_experiment( &["key1".to_string(), "key2".to_string()], &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -271,7 +273,7 @@ fn test_is_valid_experiment_no_restrictions_non_overlapping_experiment( Dimensions::Client("testclient2".to_string()), ]), ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -310,7 +312,7 @@ fn test_is_valid_experiment_restrict_same_keys_overlapping_ctx_overlapping_exper &experiment_override_keys, &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -347,7 +349,7 @@ fn test_is_valid_experiment_restrict_same_keys_overlapping_ctx_overlapping_exper &["key1".to_string(), "key3".to_string()], &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -384,7 +386,7 @@ fn test_is_valid_experiment_restrict_same_keys_overlapping_ctx_overlapping_exper &["key3".to_string(), "key4".to_string()], &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -423,7 +425,7 @@ fn test_is_valid_experiment_restrict_diff_keys_overlapping_ctx_overlapping_exper &experiment_override_keys, &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -460,7 +462,7 @@ fn test_is_valid_experiment_restrict_diff_keys_overlapping_ctx_overlapping_exper &["key1".to_string(), "key3".to_string()], &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -497,7 +499,7 @@ fn test_is_valid_experiment_restrict_diff_keys_overlapping_ctx_overlapping_exper &["key3".to_string(), "key4".to_string()], &experiment_context, ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -539,7 +541,7 @@ fn test_is_valid_experiment_restrict_same_keys_non_overlapping_ctx_non_overlappi Dimensions::Client("testclient2".to_string()), ]), ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -579,7 +581,7 @@ fn test_is_valid_experiment_restrict_same_keys_non_overlapping_ctx_non_overlappi Dimensions::Client("testclient2".to_string()), ]), ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( @@ -619,7 +621,7 @@ fn test_is_valid_experiment_restrict_same_keys_non_overlapping_ctx_non_overlappi Dimensions::Client("testclient2".to_string()), ]), ExperimentStatusType::CREATED, - &json!(""), + &vec![], )]; assert_eq!( diff --git a/crates/superposition_derives/Cargo.toml b/crates/superposition_derives/Cargo.toml new file mode 100644 index 00000000..75297abe --- /dev/null +++ b/crates/superposition_derives/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "superposition_derives" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +syn = { version = "2.0", features = ["full"] } +quote = "1.0" +proc-macro2 = "1.0" + +[lints] +workspace = true + +[lib] +proc-macro = true diff --git a/crates/superposition_derives/src/lib.rs b/crates/superposition_derives/src/lib.rs new file mode 100644 index 00000000..36302a9f --- /dev/null +++ b/crates/superposition_derives/src/lib.rs @@ -0,0 +1,48 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +/// Implements `FromSql` trait for converting `Json` type to the type for `Pg` backend +/// +#[proc_macro_derive(JsonFromSql)] +pub fn json_from_sql_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = input.ident; + + let expanded = quote! { + impl diesel::deserialize::FromSql for #name { + fn from_sql(bytes: diesel::pg::PgValue<'_>) -> diesel::deserialize::Result { + let value = >::from_sql(bytes)?; + Ok(serde_json::from_value(value)?) + } + } + }; + + TokenStream::from(expanded) +} + +/// Implements `ToSql` trait for converting the typed data to `Json` type for `Pg` backend +/// +#[proc_macro_derive(JsonToSql)] +pub fn json_to_sql_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = input.ident; + + let expanded = quote! { + impl diesel::serialize::ToSql for #name { + fn to_sql<'b>( + &'b self, + out: &mut diesel::serialize::Output<'b, '_, diesel::pg::Pg>, + ) -> diesel::serialize::Result { + let value = serde_json::to_value(self)?; + >::to_sql( + &value, + &mut out.reborrow(), + ) + } + } + }; + + TokenStream::from(expanded) +}