Skip to content

Commit

Permalink
Merge pull request #437 from int-brain-lab/develop
Browse files Browse the repository at this point in the history
Release 2.8.0
  • Loading branch information
oliche authored Jan 19, 2022
2 parents 0ce0b86 + a61aa50 commit 8adae0a
Show file tree
Hide file tree
Showing 14 changed files with 924 additions and 140 deletions.
201 changes: 194 additions & 7 deletions brainbox/behavior/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,21 @@
from iblutil.util import Bunch
import brainbox.behavior.pyschofit as psy
import logging
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

_logger = logging.getLogger('ibllib')

TRIALS_KEYS = ['contrastLeft',
'contrastRight',
'feedbackType',
'probabilityLeft',
'choice',
'response_times',
'stimOn_times']


def get_lab_training_status(lab, date=None, details=True, one=None):
"""
Expand Down Expand Up @@ -303,14 +316,14 @@ def concatenate_trials(trials):
"""
Concatenate trials from different training sessions
:param trials: dict containing trials objects from three consective training sessions,
:param trials: dict containing trials objects from three consecutive training sessions,
keys are session dates
:type trials: Bunch
:return: trials object with data concatenated over three training sessions
:rtype: dict
"""
trials_all = Bunch()
for k in trials[list(trials.keys())[0]].keys():
for k in TRIALS_KEYS:
trials_all[k] = np.concatenate(list(trials[kk][k] for kk in trials.keys()))

return trials_all
Expand Down Expand Up @@ -395,6 +408,35 @@ def compute_performance_easy(trials):
return np.sum(trials['feedbackType'][easy_trials] == 1) / easy_trials.shape[0]


def compute_performance(trials, signed_contrast=None, block=None):
"""
Compute performance on all trials at each contrast level from trials object
:param trials: trials object that must contain contrastLeft, contrastRight and feedbackType
keys
:type trials: dict
returns: float containing performance on easy contrast trials
"""
if signed_contrast is None:
signed_contrast = get_signed_contrast(trials)

if block is None:
block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
else:
block_idx = trials.probabilityLeft == block

if not np.any(block_idx):
return np.nan * np.zeros(2)

contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
rightward = trials.choice == -1
# Calculate the proportion rightward for each contrast type
prob_choose_right = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) &
block_idx]))(contrasts)

return prob_choose_right, contrasts, n_contrasts


def compute_n_trials(trials):
"""
Compute number of trials in trials object
Expand All @@ -418,6 +460,7 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
:type block: float
:return: array of psychometric fit parameters - bias, threshold, lapse high, lapse low
"""

if signed_contrast is None:
signed_contrast = get_signed_contrast(trials)

Expand All @@ -429,11 +472,7 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
if not np.any(block_idx):
return np.nan * np.zeros(4)

contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
rightward = trials.choice == -1
# Calculate the proportion rightward for each contrast type
prob_choose_right = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) &
block_idx]))(contrasts)
prob_choose_right, contrasts, n_contrasts = compute_performance(trials, signed_contrast=signed_contrast, block=block)

psych, _ = psy.mle_fit_psycho(
np.vstack([contrasts, n_contrasts, prob_choose_right]),
Expand Down Expand Up @@ -471,6 +510,31 @@ def compute_median_reaction_time(trials, stim_on_type='stimOn_times', signed_con
return reaction_time


def compute_reaction_time(trials, stim_on_type='stimOn_times', signed_contrast=None, block=None):
"""
Compute median reaction time for all contrasts
:param trials: trials object that must contain response_times and stimOn_times
:param stim_on_type:
:param signed_contrast:
:param block:
:return:
"""

if signed_contrast is None:
signed_contrast = get_signed_contrast(trials)

if block is None:
block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
else:
block_idx = trials.probabilityLeft == block

contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
reaction_time = np.vectorize(lambda x: np.nanmedian((trials.response_times - trials[stim_on_type])
[(x == signed_contrast) & block_idx]))(contrasts)

return reaction_time, contrasts, n_contrasts


def criterion_1a(psych, n_trials, perf_easy):
"""
Returns bool indicating whether criterion for trained_1a is met. All criteria documented here
Expand Down Expand Up @@ -508,3 +572,126 @@ def criterion_delay(n_trials, perf_easy):
"""
criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9)
return criterion


def plot_psychometric(trials, ax=None, title=None, **kwargs):
"""
Function to plot pyschometric curve plots a la datajoint webpage
:param trials:
:return:
"""

signed_contrast = get_signed_contrast(trials)
contrasts_fit = np.arange(-100, 100)

prob_right_50, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5)
prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit)

prob_right_20, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2)
prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit)

prob_right_80, contrasts, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8)
prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit)

cmap = sns.diverging_palette(20, 220, n=3, center="dark")

if not ax:
fig, ax = plt.subplots(**kwargs)
else:
fig = plt.gcf()

# TODO error bars

fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1])
data_50 = ax.scatter(contrasts, prob_right_50, color=cmap[1])
fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0])
data_20 = ax.scatter(contrasts, prob_right_20, color=cmap[0])
fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2])
data_80 = ax.scatter(contrasts, prob_right_80, color=cmap[2])
ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80],
['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'],
loc='upper left')
ax.set_ylim(-0.05, 1.05)
ax.set_ylabel('Probability choosing right')
ax.set_xlabel('Contrasts')
if title:
ax.set_title(title)

return fig, ax


def plot_reaction_time(trials, ax=None, title=None, **kwargs):
"""
Function to plot reaction time against contrast a la datajoint webpage (inversed for some reason??)
:param trials:
:return:
"""

signed_contrast = get_signed_contrast(trials)
reaction_50, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5)
reaction_20, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2)
reaction_80, contrasts, _ = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8)

cmap = sns.diverging_palette(20, 220, n=3, center="dark")

if not ax:
fig, ax = plt.subplots(**kwargs)
else:
fig = plt.gcf()

data_50 = ax.plot(contrasts, reaction_50, '-o', color=cmap[1])
data_20 = ax.plot(contrasts, reaction_20, '-o', color=cmap[0])
data_80 = ax.plot(contrasts, reaction_80, '-o', color=cmap[2])

# TODO error bars

ax.legend([data_50[0], data_20[0], data_80[0]],
['p_left=0.5 data', 'p_left=0.2 data', 'p_left=0.8 data'],
loc='upper left')
ax.set_ylabel('Reaction time (s)')
ax.set_xlabel('Contrasts')

if title:
ax.set_title(title)

return fig, ax


def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
"""
Function to plot reaction time with trial number a la datajoint webpage
:param trials:
:param stim_on_type:
:param ax:
:param title:
:param kwargs:
:return:
"""

reaction_time = pd.DataFrame()
reaction_time['reaction_time'] = trials.response_times - trials[stim_on_type]
reaction_time.index = reaction_time.index + 1
reaction_time_rolled = reaction_time['reaction_time'].rolling(window=10).median()
reaction_time_rolled = reaction_time_rolled.where((pd.notnull(reaction_time_rolled)), None)
reaction_time = reaction_time.where((pd.notnull(reaction_time)), None)

if not ax:
fig, ax = plt.subplots(**kwargs)
else:
fig = plt.gcf()

ax.scatter(np.arange(len(reaction_time.values)), reaction_time.values, s=16, color='darkgray')
ax.plot(np.arange(len(reaction_time_rolled.values)), reaction_time_rolled.values, color='k', linewidth=2)
ax.set_yscale('log')
ax.set_ylim(0.1, 100)
ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
ax.set_ylabel('Reaction time (s)')
ax.set_xlabel('Trial number')
if title:
ax.set_title(title)

return fig, ax
Loading

0 comments on commit 8adae0a

Please sign in to comment.