11"""Trial steps functions adapted from BlueNetworkActivityComparison/bnac/onset.py."""
22
3+ import logging
4+ from pathlib import Path
5+
36import numpy as np
47from scipy .ndimage import gaussian_filter
58
9+ L = logging .getLogger (__name__ )
10+
11+
12+ def _get_bounds (params ):
13+ lower_bound , upper_bound = params ["bounds" ]
14+ assert lower_bound <= 0
15+ assert upper_bound >= 0
16+ return lower_bound , upper_bound
17+
618
719def _histogram_from_spikes (spikes , params ):
8- onset_test_window_length = params ["post_window" ][1 ] - params ["pre_window" ][0 ]
20+ lower_bound , upper_bound = _get_bounds (params )
21+ onset_test_window_length = upper_bound - lower_bound
922 histogram , _ = np .histogram (
1023 spikes ,
11- range = [ params [ "pre_window" ][ 0 ], params [ "post_window" ][ 1 ]] ,
24+ range = ( lower_bound , upper_bound ) ,
1225 bins = int (onset_test_window_length * params ["histo_bins_per_ms" ]),
1326 )
1427 return histogram
1528
1629
1730def _onset_from_histogram (histogram , params ):
18- onset_dict = {}
31+ lower_bound , _ = _get_bounds ( params )
1932 smoothed_histogram = gaussian_filter (histogram , sigma = params ["smoothing_width" ])
2033
21- onset_pre_window_length = params ["pre_window" ][1 ] - params ["pre_window" ][0 ]
22- onset_zeroed_post_start = params ["post_window" ][0 ] - params ["pre_window" ][0 ]
23- pre_smoothed_histogram = smoothed_histogram [
24- : onset_pre_window_length * params ["histo_bins_per_ms" ]
25- ]
26- post_smoothed_histogram = smoothed_histogram [
27- onset_zeroed_post_start * params ["histo_bins_per_ms" ] :
28- ]
29-
30- onset_dict ["pre_mean" ] = np .mean (pre_smoothed_histogram )
31- onset_dict ["pre_std" ] = np .std (pre_smoothed_histogram )
32- onset_dict ["post_max" ] = np .max (post_smoothed_histogram )
34+ pre_window_length = - lower_bound
35+ pre_window_bins = int (pre_window_length * params ["histo_bins_per_ms" ])
36+ pre_smoothed_histogram = smoothed_histogram [:pre_window_bins ]
37+ post_smoothed_histogram = smoothed_histogram [pre_window_bins :]
38+
39+ onset_dict = {
40+ "pre_mean" : np .mean (pre_smoothed_histogram ),
41+ "pre_std" : np .std (pre_smoothed_histogram ),
42+ "post_max" : np .max (post_smoothed_histogram ),
43+ }
3344 onset_dict ["pre_mean_post_max_ratio" ] = onset_dict ["pre_mean" ] / onset_dict ["post_max" ]
3445
35- where_above_thresh = np .where (
36- post_smoothed_histogram
37- > (onset_dict ["pre_mean" ] + params ["threshold_std_multiple" ] * onset_dict ["pre_std" ])
38- )[0 ]
46+ threshold = onset_dict ["pre_mean" ] + params ["threshold_std_multiple" ] * onset_dict ["pre_std" ]
47+ where_above_thresh = np .where (post_smoothed_histogram > threshold )[0 ]
3948 index_above_std = 0
4049 if len (where_above_thresh ) > 0 :
4150 index_above_std = where_above_thresh [0 ]
@@ -45,13 +54,10 @@ def _onset_from_histogram(histogram, params):
4554 cortical_onset_index = index_above_std
4655
4756 onset_dict ["cortical_onset" ] = (
48- float (cortical_onset_index ) / float (params ["histo_bins_per_ms" ])
49- + float (params ["post_window" ][0 ])
50- + params ["ms_post_offset" ]
57+ float (cortical_onset_index ) / float (params ["histo_bins_per_ms" ]) + params ["ms_post_offset" ]
5158 )
52- if params [ "fig_paths" ] :
59+ if params . get ( "figures_path" ) :
5360 _plot (smoothed_histogram , params , onset_dict )
54- onset_dict ["trial_steps_value" ] = onset_dict ["cortical_onset" ]
5561 return onset_dict
5662
5763
@@ -63,27 +69,44 @@ def _plot(smoothed_histogram, params, onset_dict):
6369 import matplotlib .pyplot as plt
6470 import seaborn as sns
6571
72+ lower_bound , upper_bound = _get_bounds (params )
6673 plt .figure ()
6774 sns .set ()
6875 sns .set_style ("ticks" )
69- x_vals = list (np .arange (params [ "pre_window" ][ 0 ], params [ "post_window" ][ 1 ] , 0.2 ))
76+ x_vals = list (np .arange (lower_bound , upper_bound , 0.2 ))
7077 plt .plot (x_vals , smoothed_histogram )
7178 plt .scatter (
7279 onset_dict ["cortical_onset" ],
7380 [onset_dict ["pre_mean" ] + params ["threshold_std_multiple" ] * onset_dict ["pre_std" ]],
7481 )
75- plt .gca ().set_xlim ([params [ "pre_window" ][ 0 ], params [ "post_window" ][ 1 ] ])
82+ plt .gca ().set_xlim ([lower_bound , upper_bound ])
7683 plt .gca ().set_xlabel ("Time (ms)" )
7784 plt .gca ().set_ylabel ("Number of spikes" )
7885 plt .gca ().spines ["right" ].set_visible (False )
7986 plt .gca ().spines ["top" ].set_visible (False )
80- for fig_path in params ["fig_paths" ]:
81- plt .savefig (fig_path )
87+ filepath = Path (params ["base_path" ], params ["figures_path" ], "plot.pdf" )
88+ L .info ("Figures path: %s" , filepath )
89+ filepath .parent .mkdir (exist_ok = True )
90+ plt .savefig (filepath )
8291 plt .close ()
8392
8493
85- def onset_from_spikes (spikes , params ):
86- """Calculate trial steps from spikes."""
94+ def onset_from_spikes (spikes_list , params ):
95+ """Calculate the cortical onset from a list of spikes, one for each trial.
96+
97+ Args:
98+ spikes_list: list of spikes as numpy arrays.
99+ params: dictionary of parameters from the trial steps configuration.
100+
101+ Returns:
102+ float representing the dynamic offset to be added to the initial offset of each trial step.
103+ """
104+ L .info (
105+ "onset_from_spikes: processing %s arrays of spikes using params=%r" ,
106+ len (spikes_list ),
107+ params ,
108+ )
109+ spikes = np .concatenate (spikes_list )
87110 histogram = _histogram_from_spikes (spikes , params )
88111 onset_dict = _onset_from_histogram (histogram , params )
89- return onset_dict
112+ return onset_dict [ "cortical_onset" ]
0 commit comments