Skip to content

Commit

Permalink
Merge pull request #51 from ReactiveBayes/dev-combination-callbacks
Browse files Browse the repository at this point in the history
Add optional callbacks to the collectLatest and combineLatestUpdates
  • Loading branch information
bvdmitri authored Apr 10, 2024
2 parents 9d331b2 + 336af9c commit 9191753
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 27 deletions.
41 changes: 26 additions & 15 deletions src/observable/collected.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ export collectLatest
import Base: show

"""
collectLatest(sources::S, mappingFn::F = copy) where { S, F }
collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy)
collectLatest(sources::S, mappingFn::F = copy, callbackFn::C = nothing)
collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy, callbackFn::C = nothing)
Collects values from multible Observables and emits it in one single array every time each inner Observable has a new value.
Reemits errors from inner observables. Completes when all inner observables completes.
# Arguments
- `sources`: input sources
- `mappingFn`: optional mappingFn applied to an array of emited values, `copy` by default, should return a Vector
- `callbackFn`: optional callback function, which is called right after `mappingFn` has been evaluated, accepts the state of the inner actor and the computed value, `nothing` by default
Note: `collectLatest` completes immediately if `sources` are empty.
Expand All @@ -37,17 +38,17 @@ subscribe!(collected, logger())
See also: [`Subscribable`](@ref), [`subscribe!`](@ref), [`combineLatest`](@ref)
"""
function collectLatest(sources::S, mappingFn::F = copy) where { S, F }
function collectLatest(sources::S, mappingFn::F = copy, callbackFn::C = nothing) where { S, F, C }
T = union_type(sources)
R = similar_typeof(sources, T)
return CollectLatestObservable{T, S, R, F}(sources, mappingFn)
return CollectLatestObservable{T, S, R, F, C}(sources, mappingFn, callbackFn)
end

collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy) where { T, R, S, F } = CollectLatestObservable{T, S, R, F}(sources, mappingFn)
collectLatest(::Type{T}, ::Type{R}, sources::S, mappingFn::F = copy, callbackFn::C = nothing) where { T, R, S, F, C } = CollectLatestObservable{T, S, R, F, C}(sources, mappingFn, callbackFn)

##

struct CollectLatestObservableWrapper{L, A, S, B, T, F}
struct CollectLatestObservableWrapper{L, A, S, B, T, F, C}
actor :: A
storage :: S

Expand All @@ -56,25 +57,30 @@ struct CollectLatestObservableWrapper{L, A, S, B, T, F}
ustatus :: B # Updates status
subscriptions :: T
mappingFn :: F
callbackFn :: C

CollectLatestObservableWrapper{L, A, S, B, T, F}(actor::A, storage::S, cstatus::B, vstatus::B, ustatus::B, subscriptions::T, mappingFn::F) where {L, A, S, B, T, F} = begin
return new(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn)
CollectLatestObservableWrapper{L, A, S, B, T, F, C}(actor::A, storage::S, cstatus::B, vstatus::B, ustatus::B, subscriptions::T, mappingFn::F, callbackFn::C) where {L, A, S, B, T, F, C} = begin
return new(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn, callbackFn)
end
end

function CollectLatestObservableWrapper(::Type{L}, actor::A, storage::S, mappingFn::F) where { L, A, S, F }
function CollectLatestObservableWrapper(::Type{L}, actor::A, storage::S, mappingFn::F, callbackFn::C) where { L, A, S, F, C }
nsize = size(storage)
cstatus = falses(nsize)
vstatus = falses(nsize)
ustatus = falses(nsize)
subscriptions = fill!(similar(storage, Teardown), voidTeardown)
return CollectLatestObservableWrapper{L, A, S, typeof(cstatus), typeof(subscriptions), F}(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn)
return CollectLatestObservableWrapper{L, A, S, typeof(cstatus), typeof(subscriptions), F, C}(actor, storage, cstatus, vstatus, ustatus, subscriptions, mappingFn, callbackFn)
end

cstatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.cstatus[index]
vstatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.vstatus[index]
ustatus(wrapper::CollectLatestObservableWrapper, index::CartesianIndex) = @inbounds wrapper.ustatus[index]

fill_cstatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.cstatus, value)
fill_vstatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.vstatus, value)
fill_ustatus!(wrapper::CollectLatestObservableWrapper, value) = fill!(wrapper.ustatus, value)

dispose(wrapper::CollectLatestObservableWrapper) = begin fill!(wrapper.cstatus, true); foreach(s -> unsubscribe!(s), wrapper.subscriptions) end

struct CollectLatestObservableInnerActor{L, I <: CartesianIndex, W} <: Actor{L}
Expand All @@ -94,7 +100,11 @@ function next_received!(wrapper::CollectLatestObservableWrapper, data, index::Ca
@inbounds wrapper.ustatus[index] = true
if all(wrapper.vstatus) && !all(wrapper.cstatus)
unsafe_copyto!(wrapper.vstatus, 1, wrapper.cstatus, 1, length(wrapper.vstatus))
next!(wrapper.actor, wrapper.mappingFn(wrapper.storage))
value = wrapper.mappingFn(wrapper.storage)
next!(wrapper.actor, value)
if !isnothing(wrapper.callbackFn)
wrapper.callbackFn(wrapper, value)
end
end
end

Expand All @@ -120,15 +130,16 @@ end

##

@subscribable struct CollectLatestObservable{T, S, R, F} <: Subscribable{R}
sources :: S
mappingFn :: F
@subscribable struct CollectLatestObservable{T, S, R, F, C} <: Subscribable{R}
sources :: S
mappingFn :: F
callbackFn :: C
end

function on_subscribe!(observable::CollectLatestObservable{L}, actor::A) where { L, A }
sources = observable.sources
storage = similar(sources, L)
wrapper = CollectLatestObservableWrapper(L, actor, storage, observable.mappingFn)
wrapper = CollectLatestObservableWrapper(L, actor, storage, observable.mappingFn, observable.callbackFn)
W = typeof(wrapper)

if length(sources) !== 0
Expand Down
36 changes: 24 additions & 12 deletions src/observable/combined_updates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ See also: [`Subscribable`](@ref), [`subscribe!`](@ref), [`PushEach`](@ref), [`Pu
"""
function combineLatestUpdates end

combineLatestUpdates(; strategy = PushEach()) = error("combineLatestUpdates operator expects at least one inner observable on input")
combineLatestUpdates(args...; strategy = PushEach()) = combineLatestUpdates(args, strategy)
combineLatestUpdates(sources::S, strategy::G = PushEach()) where { S <: Tuple, G } = CombineLatestUpdatesObservable{S, G}(sources, strategy)
combineLatestUpdates(; strategy = PushEach()) = error("combineLatestUpdates operator expects at least one inner observable on input")
combineLatestUpdates(args...; strategy = PushEach()) = combineLatestUpdates(args, strategy)
combineLatestUpdates(sources::S, strategy::G = PushEach(), ::Type{R} = S, mappingFn::F = identity, callbackFn::C = nothing) where { S <: Tuple, R, G, F, C } = CombineLatestUpdatesObservable{R, S, G, F, C}(sources, strategy, mappingFn, callbackFn)

##

Expand All @@ -37,32 +37,42 @@ on_complete!(actor::CombineLatestUpdatesInnerActor{L, W}) where { L, W } =

##

struct CombineLatestUpdatesActorWrapper{S, A, G, U}
struct CombineLatestUpdatesActorWrapper{S, A, G, U, F, C}
sources :: S
actor :: A
nsize :: Int
strategy :: G # Push update strategy
updates :: U # Updates
subscriptions :: Vector{Teardown}
mappingFn :: F
callbackFn :: C
end

function CombineLatestUpdatesActorWrapper(sources::S, actor::A, strategy::G) where { S, A, G }
function CombineLatestUpdatesActorWrapper(sources::S, actor::A, strategy::G, mappingFn::F, callbackFn::C) where { S, A, G, F, C }
updates = getustorage(S)
nsize = length(sources)
subscriptions = fill!(Vector{Teardown}(undef, nsize), voidTeardown)
return CombineLatestUpdatesActorWrapper(sources, actor, nsize, strategy, updates, subscriptions)
return CombineLatestUpdatesActorWrapper(sources, actor, nsize, strategy, updates, subscriptions, mappingFn, callbackFn)
end

push_update!(wrapper::CombineLatestUpdatesActorWrapper) = push_update!(wrapper.nsize, wrapper.updates, wrapper.strategy)

dispose(wrapper::CombineLatestUpdatesActorWrapper) = begin fill_cstatus!(wrapper.updates, true); foreach(s -> unsubscribe!(s), wrapper.subscriptions) end

fill_cstatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_cstatus!(wrapper.updates, value)
fill_vstatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_vstatus!(wrapper.updates, value)
fill_ustatus!(wrapper::CombineLatestUpdatesActorWrapper, value) = fill_ustatus!(wrapper.updates, value)

function next_received!(wrapper::CombineLatestUpdatesActorWrapper, data, index::Int)
vstatus!(wrapper.updates, index, true)
ustatus!(wrapper.updates, index, true)
if all_vstatus(wrapper.updates) && !all_cstatus(wrapper.updates)
push_update!(wrapper)
next!(wrapper.actor, wrapper.sources)
value = wrapper.mappingFn(wrapper.sources)
next!(wrapper.actor, value)
if !isnothing(wrapper.callbackFn)
wrapper.callbackFn(wrapper, value)
end
end
end

Expand All @@ -88,15 +98,17 @@ end

##

@subscribable struct CombineLatestUpdatesObservable{S, G} <: Subscribable{S}
sources :: S
strategy :: G
@subscribable struct CombineLatestUpdatesObservable{R, S, G, F, C} <: Subscribable{R}
sources :: S
strategy :: G
mappingFn :: F
callbackFn :: C
end

getrecent(observable::CombineLatestUpdatesObservable) = getrecent(observable.sources)

function on_subscribe!(observable::CombineLatestUpdatesObservable{S, G}, actor::A) where { S, G, A }
wrapper = CombineLatestUpdatesActorWrapper(observable.sources, actor, observable.strategy)
function on_subscribe!(observable::CombineLatestUpdatesObservable, actor)
wrapper = CombineLatestUpdatesActorWrapper(observable.sources, actor, observable.strategy, observable.mappingFn, observable.callbackFn)

__combine_latest_updates_unrolled_fill_subscriptions!(observable.sources, wrapper)

Expand Down
46 changes: 46 additions & 0 deletions test/observable/test_observable_collect_latest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,52 @@ include("../test_helpers.jl")
unsubscribe!(subscription)
end

@testset begin
source1 = Subject(Int)
source2 = Subject(Int)

callbackCalled = []
callbackFn = (wrapper, value) -> begin
# We reset the state of the `vstatus`
if isequal(value, "2")
Rocket.fill_vstatus!(wrapper, true)
push!(callbackCalled, true)
else
push!(callbackCalled, false)
end
end

combined = collectLatest(Int, String, [ source1, source2 ], (values) -> string(sum(values)), callbackFn)
values = []
subscription = subscribe!(combined, (value) -> push!(values, value))

@test values == []
@test callbackCalled == []
next!(source1, 0)
@test values == []
@test callbackCalled == []
next!(source2, 0)
@test values == ["0"]
@test callbackCalled == [false]

next!(source1, 1)
@test values == ["0"]
@test callbackCalled == [false]
next!(source2, 1)
@test values == ["0", "2"]
@test callbackCalled == [false, true]

next!(source1, 2)
@test values == ["0", "2", "3"] # this is hapenning because the callback should have been called
@test callbackCalled == [false, true, false]
next!(source1, 2)
@test values == ["0", "2", "3"]
@test callbackCalled == [false, true, false]
next!(source2, 2)
@test values == ["0", "2", "3", "4"]
@test callbackCalled == [false, true, false, false]
end

end

end
46 changes: 46 additions & 0 deletions test/observable/test_observable_combine_updates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,52 @@ include("../test_helpers.jl")
unsubscribe!(subscription)
end

@testset begin
source1 = RecentSubject(Int)
source2 = RecentSubject(Int)

callbackCalled = []
callbackFn = (wrapper, value) -> begin
# We reset the state of the `vstatus`
if isequal(value, "2")
Rocket.fill_vstatus!(wrapper, true)
push!(callbackCalled, true)
else
push!(callbackCalled, false)
end
end

combined = combineLatestUpdates((source1, source2), PushNew(), String, (sources) -> string(sum(Rocket.getrecent.(sources))), callbackFn)
values = []
subscription = subscribe!(combined, (value) -> push!(values, value))

@test values == []
@test callbackCalled == []
next!(source1, 0)
@test values == []
@test callbackCalled == []
next!(source2, 0)
@test values == ["0"]
@test callbackCalled == [false]

next!(source1, 1)
@test values == ["0"]
@test callbackCalled == [false]
next!(source2, 1)
@test values == ["0", "2"]
@test callbackCalled == [false, true]

next!(source1, 2)
@test values == ["0", "2", "3"] # this is hapenning because the callback should have been called
@test callbackCalled == [false, true, false]
next!(source1, 2)
@test values == ["0", "2", "3"]
@test callbackCalled == [false, true, false]
next!(source2, 2)
@test values == ["0", "2", "3", "4"]
@test callbackCalled == [false, true, false, false]
end

end

end

0 comments on commit 9191753

Please sign in to comment.