Skip to content

Commit

Permalink
Merge remote-tracking branch 'Po-Chen/master' into pckuo-master
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Nov 28, 2022
2 parents c564299 + 6b55aa3 commit 902a07f
Show file tree
Hide file tree
Showing 109 changed files with 17,603 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,6 @@ annotation_10.nrrd
notebook/dave_notebooks/*
report/*
**/*.nwb

# data
data/*
148 changes: 148 additions & 0 deletions notebook/pckuo/UtilFunctions_KH.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 25 16:57:31 2022
Functions for Ophys/DataJoint-Behavior Analysis
@author: kenta.hagihara
"""


import datajoint as dj
import numpy as np
from pipeline import lab, get_schema_name, experiment, foraging_model, ephys, foraging_analysis, histology, ccf


def _get_independent_variableKH(unit_key, model_id, var_name=None):
#modified from _get_unit_independent_variable
"""
Get independent variable over trial for a specified unit (ignored trials are skipped)
@param unit_key:
@param model_id:
@param var_name
@return: DataFrame (trial, variables)
"""

#hemi = _get_units_hemisphere(unit_key)
contra, ipsi = ['right', 'left'] #if hemi == 'left' else ['left', 'right']

# Get latent variables from model fitting
q_latent_variable = (foraging_model.FittedSessionModel.TrialLatentVariable
& unit_key
& {'model_id': model_id})

# Flatten latent variables to generate columns like 'left_action_value', 'right_choice_prob'
latent_variables = q_latent_variable.heading.secondary_attributes
q_latent_variable_all = dj.U('trial') & q_latent_variable
for lv in latent_variables:
for prefix, side in zip(['left_', 'right_', 'contra_', 'ipsi_'],
['left', 'right', contra, ipsi]):
# Better way here?
q_latent_variable_all *= eval(f"(q_latent_variable & {{'water_port': '{side}'}}).proj({prefix}{lv}='{lv}', {prefix}='water_port')")

# Add relative and total value
q_latent_variable_all = q_latent_variable_all.proj(...,
relative_action_value_lr='right_action_value - left_action_value',
relative_action_value_ic='contra_action_value - ipsi_action_value',
total_action_value='contra_action_value + ipsi_action_value')

# Add choice
q_independent_variable = (q_latent_variable_all * experiment.WaterPortChoice).proj(...,
choice='water_port',
choice_lr='water_port="right"',
choice_ic=f'water_port="{contra}"')

# Add reward
q_independent_variable = (q_independent_variable * experiment.BehaviorTrial.proj('outcome')).proj(...,
reward='outcome="hit"'
)

df = q_independent_variable.fetch(format='frame', order_by='trial').reset_index()

# Compute RPE
df['rpe'] = np.nan
df.loc[0, 'rpe'] = df.reward[0]
for side in ['left', 'right']:
_idx = df[(df.choice == side) & (df.trial > 1)].index
df.loc[_idx, 'rpe'] = df.reward.iloc[_idx] - df[f'{side}_action_value'].iloc[_idx - 1].values

return df if var_name is None else df[['trial', var_name]]


#### For trial alignment
def align_phys_to_behav_trials(phys_barcode, behav_barcode, behav_trialN=None):
'''
Align physiology trials (ephys/ophys) to behavioral trials using the barcode
Input: phys_barcode (list), behav_barcode (list), behav_trialN (list)
Output: a dictionary with fields
'phys_to_behav_mapping': a list of trial mapping [phys_trialN, behav_trialN]. Use this to trialize events in phys recording
'phys_not_in_behav': phys trials that are not found in behavioral trials
'behav_not_in_phys': behavioral trials that are not found in phys trials
'phys_aligned_blocks': blocks of consecutive phys trials that are aligned with behav
'behav_aligned_blocks': blocks of consecutive behav trials that are aligned with phys (each block has the same length as phys_aligned_blocks)
'perfectly_aligned': whether phys and behav trials are perfectly aligned
'''

if behav_trialN is None:
behav_trialN = np.r_[1:len(behav_barcode) + 1]
else:
behav_trialN = np.array(behav_trialN)

behav_barcode = np.array(behav_barcode)

phys_to_behav_mapping = [] # A list of [phys_trial, behav_trial]
phys_not_in_behav = [] # Happens when the bpod protocol is terminated during a trial (incomplete bpod trial will not be ingested to behavior)
behav_not_in_phys = [] # Happens when phys recording starts later or stops earlier than the bpod protocol
behav_aligned_blocks = [] # A list of well-aligned blocks
phys_aligned_blocks = [] # A list of well-aligned blocks
bitCollision = [] # Add the trial numbers with the same bitcode for restrospective sanity check purpose (220817KH)
behav_aligned_last = -999
phys_aligned_last = -999
in_a_continous_aligned_block = False # A flag indicating whether the previous phys trial is in a continuous aligned block

for phys_trialN_this, phys_barcode_this in zip(range(1, len(phys_barcode + ['fake']) + 1), phys_barcode + ['fake']): # Add a fake value to deal with the boundary effect
behav_trialN_this = behav_trialN[behav_barcode == phys_barcode_this]
#assert len(behav_trialN_this) <= 1 # Otherwise bitcode must be problematic

#'''
if len(behav_trialN_this) > 1:

bitCollision.append(behav_trialN_this)
closest_idx = np.abs(np.array(behav_trialN_this) - phys_trialN_this).argmin()
behav_trialN_this = behav_trialN_this[closest_idx:closest_idx+1] #only retaining the closest trialN (220817KH)
#'''
if len(behav_trialN_this) == 0 or behav_trialN_this - behav_aligned_last > 1: # The current continuously aligned block is broken
# Add a continuously aligned block
if behav_aligned_last != -999 and phys_aligned_last != -999 and in_a_continous_aligned_block:
behav_aligned_blocks.append([behav_aligned_block_start_this, behav_aligned_last])
phys_aligned_blocks.append([phys_aligned_block_start_this, phys_aligned_last])

in_a_continous_aligned_block = False

if len(behav_trialN_this) == 0:
phys_not_in_behav.append(phys_trialN_this)
else:
phys_to_behav_mapping.append([phys_trialN_this, behav_trialN_this[0]]) # The main output

# Cache the last behav-phys matched pair
behav_aligned_last = behav_trialN_this[0]
phys_aligned_last = phys_trialN_this

# Cache the start of each continuously aligned block
if not in_a_continous_aligned_block: # A new continuous block just starts
behav_aligned_block_start_this = behav_trialN_this[0]
phys_aligned_block_start_this = phys_trialN_this

# Switch on the flag
in_a_continous_aligned_block = True

phys_not_in_behav.pop(-1) # Remove the last fake value
behav_not_in_phys = list(np.setdiff1d(behav_trialN, [b for _, b in phys_to_behav_mapping]))

return {'phys_to_behav_mapping': phys_to_behav_mapping,
'phys_not_in_behav': phys_not_in_behav,
'behav_not_in_phys': behav_not_in_phys,
'phys_aligned_blocks': phys_aligned_blocks,
'behav_aligned_blocks': behav_aligned_blocks,
'perfectly_aligned': len(phys_not_in_behav + behav_not_in_phys) == 0
}
Binary file added notebook/pckuo/ephys_msp_decode_pop.pickle
Binary file not shown.
131 changes: 131 additions & 0 deletions notebook/pckuo/ephys_msp_decode_pop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy import signal
from scipy import stats
import itertools
import seaborn as sns
import statsmodels.api as sm
import random


import pickle
import json
json_open = open('../../dj_local_conf.json', 'r')
config = json.load(json_open)

import datajoint as dj
dj.config['database.host'] = config["database.host"]
dj.config['database.user'] = config ["database.user"]
dj.config['database.password'] = config["database.password"]
dj.conn().connect()

from pipeline import lab, get_schema_name, experiment, foraging_model, ephys, foraging_analysis, histology, ccf
from pipeline.plot import unit_psth
from pipeline.plot.foraging_model_plot import plot_session_model_comparison, plot_session_fitted_choice
from pipeline import psth_foraging
from pipeline import util
from pipeline.model import bandit_model






if __name__ == "__main__":

# load data
with open('./neurons_data_match_iti_all.pickle', 'rb') as handle:
neurons = pickle.load(handle)

with open('./q_latents_match.pickle', 'rb') as handle:
q_latents = pickle.load(handle)

with open('./pseudo_sessions_match.pickle', 'rb') as handle:
pseudo_sessions_dict = pickle.load(handle)


# msp fit, poppulation decoding
# compute the test statistic: sum of residuals from regression models
# in all possible permutations


regions_to_fit = ['ALM', 'PL', 'ACA', 'ORB', 'LSN', 'striatum', 'MD']
target_variables = ['Q_left', 'Q_right', 'sigma_Q', 'delta_Q']

msp_decode_pop_columns = ['gen_session_perm', 'fit_session_perm', 'target_variable', 'residuals', 'mse_total']
df_msp_decode_pop_dict = {region: pd.DataFrame(columns=msp_decode_pop_columns) for region in regions_to_fit}


# fit on all permutations of sessions
for region in regions_to_fit:
print(f'region {region}')
neurons_region = neurons[region]
sessions_with_unit = np.unique(neurons_region['session'].values)

df_Qs = q_latents[region]
# calculate minimal session length
sess_min_len = 100000
for sess in sessions_with_unit:
df_Qs_sess = df_Qs[df_Qs['session']==sess].sort_values(by=['trial'])
sess_len = len(df_Qs_sess)
sess_min_len = min(sess_min_len, sess_len)
print(f' min session length: {sess_min_len}')

df_msp_fit = df_msp_decode_pop_dict[region]


for j, p in enumerate(itertools.permutations(range(len(sessions_with_unit)))):
print(f' permutation {j}: {p}')

for target_variable in target_variables:
gen_session_perm = []
fit_session_perm = []
residuals = []
mse_total = []

for session_id in range(len(sessions_with_unit)):
gen_session = sessions_with_unit[session_id]
fit_session = sessions_with_unit[p[session_id]]
gen_session_perm.append(gen_session)
fit_session_perm.append(fit_session)
# print(f' using gen_session {gen_session} and fit_session {fit_session}')

# get population activity
neurons_region_session = neurons_region[neurons_region['session']==gen_session]
fr = np.empty((sess_min_len,
len(neurons_region_session)))
for j in range(len(neurons_region_session)):
fr[:, j] = neurons_region_session.iloc[j]['firing_rates'][:sess_min_len]
fr = sm.add_constant(fr)
#print(f' sess {gen_session} fr shape {fr.shape}')

# get Qs
df_Qs_session = df_Qs[df_Qs['session']==fit_session].sort_values(by=['trial'])
if target_variable in ['Q_left', 'Q_right']:
X = df_Qs_session[[target_variable]][:sess_min_len]
elif target_variable == 'sigma_Q':
X = (df_Qs_session[['Q_left']].values + df_Qs_session[['Q_right']].values)[:sess_min_len]
elif target_variable == 'delta_Q':
X = (df_Qs_session[['Q_left']].values - df_Qs_session[['Q_right']].values)[:sess_min_len]
else:
raise ValueError('incorrect target variable type!')

# decoding models: fr --> X
model = sm.OLS(X, fr)
results = model.fit()
#print(f'{neuron_type} {n} {target_variable}')
#print(f' {results.f_pvalue}')
residuals.append(results.resid)
mse_total.append(results.mse_total)

residuals = np.array(residuals)
mse_total = np.array(mse_total)
df_msp_fit.loc[len(df_msp_fit.index)] = [gen_session_perm, fit_session_perm,
target_variable, residuals, mse_total]


# save the ps population decoding df if not existed
with open('./ephys_msp_decode_pop.pickle', 'wb') as handle:
pickle.dump(df_msp_decode_pop_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
Binary file added notebook/pckuo/ephys_ps_decode.pickle
Binary file not shown.
Loading

0 comments on commit 902a07f

Please sign in to comment.