Skip to content

Commit

Permalink
fix visual search
Browse files Browse the repository at this point in the history
  • Loading branch information
itsdfish committed May 18, 2023
1 parent 3b87647 commit 032ed3e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Background_Tutorials/ACTRModels_Tutorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ version = "17.4.0+0"
"""

# ╔═╡ Cell order:
# ╟─2bf04f4e-67df-45bc-a183-6b7b8c12cfe0
# ╠═2bf04f4e-67df-45bc-a183-6b7b8c12cfe0
# ╟─6c65c466-ffa6-4559-9668-85324ce39a2c
# ╟─1500b6f4-bcdc-48a5-a3ee-a62fe581caeb
# ╠═04476fd1-e79c-4cfe-bcda-ef3cb519683b
Expand Down
10 changes: 4 additions & 6 deletions Tutorial_Models/Unit10/Visual_Search/Visual_Search_Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ function logpdf(d::VisualSearch, data::Vector{Vector{Fixation}})
return LL
end

function simulate(;n_trials, parms...)
# create an experiment object containing experiment parameters
experiment = Experiment(;n_trials)
function simulate(experiment; parms...)
# generate stimuli, consisting of visual array and target for each trial
stimuli = map(_->generate_stimuli(experiment), 1:n_trials)
stimuli = map(_-> generate_stimuli(experiment), 1:experiment.n_trials)
# generate data for each trial
run_condition!(experiment, stimuli; parms...);
# return stimuli and fixation data
Expand Down Expand Up @@ -82,13 +80,13 @@ function loglikelihood_fixation(ex, actr, visicon, fixation)
end


function computeLL(stimuli, all_fixations; topdown_weight)
function computeLL(stimuli, all_fixations; topdown_weight, parms...)
ex = Experiment()
LL = 0.0
# copy and reset fields in visual array
_stimuli = set_stimuli(stimuli, topdown_weight)
for i in 1:length(all_fixations)
LL += loglikelihood_trial(ex, _stimuli[i][1], _stimuli[i][3], all_fixations[i]; topdown_weight)
LL += loglikelihood_trial(ex, _stimuli[i][1], _stimuli[i][3], all_fixations[i]; topdown_weight, parms...)
end
return LL
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,13 @@ The following code will generate fixations for 10 trials using a topdown activat
begin
# number of trials
n_trials = 10
# fixed parameters
fixed_parms = (noise=false, rnd_time=false)
# weight for topdown activation
topdown_weight = 0.66
# create an experiment object containing experiment parameters
_experiment = Experiment(;n_trials)
stimuli,all_fixations = simulate(_experiment ; topdown_weight);
stimuli,all_fixations = simulate(_experiment ; topdown_weight, fixed_parms...);
end

# ╔═╡ 232d202c-6585-440f-b936-6a6ca6ccb7b4
Expand Down Expand Up @@ -596,6 +598,7 @@ Nyamsuren, E., & Taatgen, N. A. (2013). Pre-attentive and attentive vision modul
Moran, R., Zehetleitner, M., Müller, H. J., & Usher, M. (2013). Competitive guided search: Meeting the challenge of benchmark RT distributions. Journal of Vision, 13(8), 24-24.
"""


# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
Expand All @@ -615,7 +618,7 @@ PlutoUI = "~0.7.51"
Revise = "~3.5.2"
StatsPlots = "~0.15.5"
Turing = "~0.25.1"
VisualSearchACTR = "~0.3.5"
VisualSearchACTR = "~0.3.6"
"""

# ╔═╡ 00000000-0000-0000-0000-000000000002
Expand All @@ -624,14 +627,19 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
julia_version = "1.9.0"
manifest_format = "2.0"
project_hash = "a7bae6e53c8df1f66126fb85730b67cbd8161048"
project_hash = "9e3aa3a4f1aa4594cb55a9dd337c8920837a32b9"
[[deps.ACTRModels]]
deps = ["ConcreteStructs", "Distributions", "Parameters", "Pkg", "PrettyTables", "Random", "Reexport", "SafeTestsets", "SequentialSamplingModels", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "ca795df6568ce4d3f386e8c775f798e6e7b71565"
uuid = "c095b0ea-a6ca-5cbd-afed-dbab2e976880"
version = "0.11.1"
[[deps.ADTypes]]
git-tree-sha1 = "dcfdf328328f2645531c4ddebf841228aef74130"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
version = "0.1.3"
[[deps.ATK_jll]]
deps = ["Artifacts", "Glib_jll", "JLLWrappers", "Libdl"]
git-tree-sha1 = "a2ecb68d240333fe63bea1965b71884e98c2d0f0"
Expand Down Expand Up @@ -1608,9 +1616,9 @@ version = "2.1.1"
[[deps.LogDensityProblemsAD]]
deps = ["DocStringExtensions", "LogDensityProblems", "Requires", "SimpleUnPack"]
git-tree-sha1 = "5f219f583a399381dc147b984648429bf8c3fc6a"
git-tree-sha1 = "b726468867eb032ebd7aba0337213eb18ed0566b"
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
version = "1.4.2"
version = "1.4.3"
[deps.LogDensityProblemsAD.extensions]
LogDensityProblemsADEnzymeExt = "Enzyme"
Expand Down Expand Up @@ -2111,10 +2119,10 @@ uuid = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
version = "0.0.1"
[[deps.SciMLBase]]
deps = ["ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"]
git-tree-sha1 = "e803672f8d58e9937f59923dd3b159c9b7e1838b"
deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"]
git-tree-sha1 = "75552338dda481baeb9b9e171f73ecd0171e8f34"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
version = "1.92.0"
version = "1.92.2"
[[deps.SciMLOperators]]
deps = ["ArrayInterface", "DocStringExtensions", "Lazy", "LinearAlgebra", "Setfield", "SparseArrays", "StaticArraysCore", "Tricks"]
Expand Down Expand Up @@ -2410,9 +2418,9 @@ version = "0.2.0"
[[deps.VisualSearchACTR]]
deps = ["ACTRModels", "ArgCheck", "Cairo", "ColorSchemes", "Colors", "ConcreteStructs", "Crayons", "Distributions", "Graphics", "Gtk", "Random", "Reexport", "SafeTestsets", "Statistics", "StatsBase", "StatsPlots"]
git-tree-sha1 = "7ef983cfd94f94b65e09ba5f49a85ad02d2400f4"
git-tree-sha1 = "26c417ccb63ca2c047c68a08c2d6972e86576075"
uuid = "ad0872d7-b137-4624-87af-d763ac8dea1a"
version = "0.3.5"
version = "0.3.6"
[[deps.Wayland_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
Expand Down Expand Up @@ -2718,7 +2726,7 @@ version = "1.4.1+0"
"""

# ╔═╡ Cell order:
# ╟─7ec2da8e-0bf9-11ed-1f84-6dfe87a9d83b
# ╠═7ec2da8e-0bf9-11ed-1f84-6dfe87a9d83b
# ╟─d1592c9c-3516-4b59-9ba3-fe33c2274bf7
# ╟─7d8bfda3-ed03-4b43-b02b-c218aad7e76f
# ╟─5549bd89-1220-4d71-9cd7-c5b286d4ffd9
Expand Down
10 changes: 6 additions & 4 deletions test/VisualSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ using SafeTestsets
@safetestset "Visual Search" begin
using VisualSearchACTR, Test, Parameters, Random
include("../Tutorial_Models/Unit10/Visual_Search/Visual_Search_Model.jl")
Random.seed!(93)
Random.seed!(95594)
n_trials = 200
topdown_weight = .66
stimuli,all_fixations = simulate(;n_trials, topdown_weight);
experiment = Experiment(;n_trials)
fixed_parms = (noise=false, rnd_time=false)
stimuli,all_fixations = simulate(experiment; topdown_weight, fixed_parms...);
x = range(.8 * topdown_weight, 1.2 * topdown_weight, length=100)
@time y = map(x->computeLL(stimuli, all_fixations; topdown_weight=x), x)
y = map(x -> computeLL(stimuli, all_fixations; topdown_weight=x, fixed_parms...), x)
mxv,mxi = findmax(y)
topdown_weight′ = x[mxi]
@test topdown_weight topdown_weight′ atol = 2e-2
@test topdown_weight topdown_weight′ atol = 4e-2
end

0 comments on commit 032ed3e

Please sign in to comment.