Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit e1a8bdb

Browse files
[NSETM-2281] Fix calculation of dynamic offset (#19)
1 parent d5362f0 commit e1a8bdb

51 files changed

Lines changed: 405 additions & 397 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/publish-sdist.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
steps:
1313
- uses: actions/checkout@v4
1414
- name: Set up Python 3.11
15-
uses: actions/setup-python@v4
15+
uses: actions/setup-python@v5
1616
with:
1717
python-version: 3.11
1818

.github/workflows/run-tox.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
steps:
1717
- uses: actions/checkout@v4
1818
- name: Set up Python ${{ matrix.python-version }}
19-
uses: actions/setup-python@v4
19+
uses: actions/setup-python@v5
2020
with:
2121
python-version: ${{ matrix.python-version }}
2222

CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
Changelog
22
=========
33

4+
Version 0.7.0
5+
-------------
6+
7+
New Features
8+
~~~~~~~~~~~~
9+
10+
- Allow to specify ``trial_steps_label`` to calculate the dynamic offset of trial steps [NSETM-2281]
11+
12+
413
Version 0.6.0
514
-------------
615

src/blueetl/config/analysis.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from copy import deepcopy
66
from itertools import chain
77
from pathlib import Path
8-
from typing import NamedTuple, Optional, Union
8+
from typing import NamedTuple, Union
99

1010
from blueetl.config.analysis_model import (
1111
FeaturesConfig,
@@ -20,13 +20,23 @@
2020
L = logging.getLogger(__name__)
2121

2222

23-
def _resolve_paths(global_config: MultiAnalysisConfig, base_path: Optional[Path] = None) -> None:
23+
def _resolve_paths(global_config: MultiAnalysisConfig, base_path: Path) -> None:
2424
"""Resolve any relative path."""
2525
base_path = base_path or Path()
2626
global_config.output = base_path / global_config.output
2727
global_config.simulation_campaign = base_path / global_config.simulation_campaign
2828

2929

30+
def _resolve_trial_steps(global_config: MultiAnalysisConfig):
31+
"""Set trial_steps_config.base_path to the same value as global_config.output.
32+
33+
In this way, the custom function can use it as the base path to save any figure.
34+
"""
35+
for config in global_config.analysis.values():
36+
for trial_steps_config in config.extraction.trial_steps.values():
37+
trial_steps_config.base_path = str(global_config.output)
38+
39+
3040
def _resolve_windows(global_config: MultiAnalysisConfig) -> None:
3141
"""Calculate the hash of any referenced windows in each single analysis configuration.
3242
@@ -140,13 +150,12 @@ def _resolve_analysis_configs(global_config: MultiAnalysisConfig) -> None:
140150
config.features = _resolve_features(config.features)
141151

142152

143-
def init_multi_analysis_configuration(
144-
global_config: dict, base_path: Optional[Path] = None
145-
) -> MultiAnalysisConfig:
153+
def init_multi_analysis_configuration(global_config: dict, base_path: Path) -> MultiAnalysisConfig:
146154
"""Return a config object from a config dict."""
147155
validate_config(global_config, schema=read_schema("analysis_config"))
148156
config = MultiAnalysisConfig(**global_config)
149157
_resolve_paths(config, base_path=base_path)
158+
_resolve_trial_steps(config)
150159
_resolve_windows(config)
151160
_resolve_analysis_configs(config)
152161
return config

src/blueetl/config/analysis_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ class WindowConfig(BaseModel):
6868
@model_validator(mode="after")
6969
def validate_values(self):
7070
"""Validate the values after loading them."""
71-
if self.trial_steps_label:
72-
raise ValueError("trial_steps_label cannot be used yet, see NSETM-2281")
7371
if self.trial_steps_list and (self.n_trials or self.trial_steps_value):
7472
raise ValueError("trial_steps_list cannot be set with n_trials or trial_steps_value")
7573
if self.n_trials > 1 and not self.trial_steps_value:
@@ -84,13 +82,22 @@ class TrialStepsConfig(BaseModel):
8482
**BaseModel.model_config,
8583
"extra": "allow",
8684
}
85+
_forbidden_extra_fields: set[str] = {
86+
"initial_offset",
87+
}
8788
function: str
88-
initial_offset: float = 0.0
8989
bounds: tuple[float, float]
9090
population: Optional[str] = None
9191
node_set: Optional[str] = None
9292
limit: Optional[int] = None
9393

94+
@model_validator(mode="after")
95+
def forbid_fields(self):
96+
"""Verify that the forbidden extra fields have not been specified."""
97+
if found := self._forbidden_extra_fields.intersection(self.model_extra):
98+
raise ValueError(f"Forbidden extra fields: {found}")
99+
return self
100+
94101

95102
class NeuronClassConfig(BaseModel):
96103
"""NeuronClassConfig Model."""

src/blueetl/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
DURATION = "duration"
2222
WINDOW_TYPE = "window_type"
2323
COUNT = "count"
24-
TRIAL_STEPS_LABEL = "trial_steps_label"
25-
TRIAL_STEPS_VALUE = "trial_steps_value"
2624
TIMES = "times"
2725
BIN = "bin"
2826
VALUE = "value"
Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,50 @@
11
"""Trial steps functions adapted from BlueNetworkActivityComparison/bnac/onset.py."""
22

3+
import logging
4+
from pathlib import Path
5+
36
import numpy as np
47
from scipy.ndimage import gaussian_filter
58

9+
L = logging.getLogger(__name__)
10+
11+
12+
def _get_bounds(params):
13+
lower_bound, upper_bound = params["bounds"]
14+
assert lower_bound <= 0
15+
assert upper_bound >= 0
16+
return lower_bound, upper_bound
17+
618

719
def _histogram_from_spikes(spikes, params):
8-
onset_test_window_length = params["post_window"][1] - params["pre_window"][0]
20+
lower_bound, upper_bound = _get_bounds(params)
21+
onset_test_window_length = upper_bound - lower_bound
922
histogram, _ = np.histogram(
1023
spikes,
11-
range=[params["pre_window"][0], params["post_window"][1]],
24+
range=(lower_bound, upper_bound),
1225
bins=int(onset_test_window_length * params["histo_bins_per_ms"]),
1326
)
1427
return histogram
1528

1629

1730
def _onset_from_histogram(histogram, params):
18-
onset_dict = {}
31+
lower_bound, _ = _get_bounds(params)
1932
smoothed_histogram = gaussian_filter(histogram, sigma=params["smoothing_width"])
2033

21-
onset_pre_window_length = params["pre_window"][1] - params["pre_window"][0]
22-
onset_zeroed_post_start = params["post_window"][0] - params["pre_window"][0]
23-
pre_smoothed_histogram = smoothed_histogram[
24-
: onset_pre_window_length * params["histo_bins_per_ms"]
25-
]
26-
post_smoothed_histogram = smoothed_histogram[
27-
onset_zeroed_post_start * params["histo_bins_per_ms"] :
28-
]
29-
30-
onset_dict["pre_mean"] = np.mean(pre_smoothed_histogram)
31-
onset_dict["pre_std"] = np.std(pre_smoothed_histogram)
32-
onset_dict["post_max"] = np.max(post_smoothed_histogram)
34+
pre_window_length = -lower_bound
35+
pre_window_bins = int(pre_window_length * params["histo_bins_per_ms"])
36+
pre_smoothed_histogram = smoothed_histogram[:pre_window_bins]
37+
post_smoothed_histogram = smoothed_histogram[pre_window_bins:]
38+
39+
onset_dict = {
40+
"pre_mean": np.mean(pre_smoothed_histogram),
41+
"pre_std": np.std(pre_smoothed_histogram),
42+
"post_max": np.max(post_smoothed_histogram),
43+
}
3344
onset_dict["pre_mean_post_max_ratio"] = onset_dict["pre_mean"] / onset_dict["post_max"]
3445

35-
where_above_thresh = np.where(
36-
post_smoothed_histogram
37-
> (onset_dict["pre_mean"] + params["threshold_std_multiple"] * onset_dict["pre_std"])
38-
)[0]
46+
threshold = onset_dict["pre_mean"] + params["threshold_std_multiple"] * onset_dict["pre_std"]
47+
where_above_thresh = np.where(post_smoothed_histogram > threshold)[0]
3948
index_above_std = 0
4049
if len(where_above_thresh) > 0:
4150
index_above_std = where_above_thresh[0]
@@ -45,13 +54,10 @@ def _onset_from_histogram(histogram, params):
4554
cortical_onset_index = index_above_std
4655

4756
onset_dict["cortical_onset"] = (
48-
float(cortical_onset_index) / float(params["histo_bins_per_ms"])
49-
+ float(params["post_window"][0])
50-
+ params["ms_post_offset"]
57+
float(cortical_onset_index) / float(params["histo_bins_per_ms"]) + params["ms_post_offset"]
5158
)
52-
if params["fig_paths"]:
59+
if params.get("figures_path"):
5360
_plot(smoothed_histogram, params, onset_dict)
54-
onset_dict["trial_steps_value"] = onset_dict["cortical_onset"]
5561
return onset_dict
5662

5763

@@ -63,27 +69,44 @@ def _plot(smoothed_histogram, params, onset_dict):
6369
import matplotlib.pyplot as plt
6470
import seaborn as sns
6571

72+
lower_bound, upper_bound = _get_bounds(params)
6673
plt.figure()
6774
sns.set()
6875
sns.set_style("ticks")
69-
x_vals = list(np.arange(params["pre_window"][0], params["post_window"][1], 0.2))
76+
x_vals = list(np.arange(lower_bound, upper_bound, 0.2))
7077
plt.plot(x_vals, smoothed_histogram)
7178
plt.scatter(
7279
onset_dict["cortical_onset"],
7380
[onset_dict["pre_mean"] + params["threshold_std_multiple"] * onset_dict["pre_std"]],
7481
)
75-
plt.gca().set_xlim([params["pre_window"][0], params["post_window"][1]])
82+
plt.gca().set_xlim([lower_bound, upper_bound])
7683
plt.gca().set_xlabel("Time (ms)")
7784
plt.gca().set_ylabel("Number of spikes")
7885
plt.gca().spines["right"].set_visible(False)
7986
plt.gca().spines["top"].set_visible(False)
80-
for fig_path in params["fig_paths"]:
81-
plt.savefig(fig_path)
87+
filepath = Path(params["base_path"], params["figures_path"], "plot.pdf")
88+
L.info("Figures path: %s", filepath)
89+
filepath.parent.mkdir(exist_ok=True)
90+
plt.savefig(filepath)
8291
plt.close()
8392

8493

85-
def onset_from_spikes(spikes, params):
86-
"""Calculate trial steps from spikes."""
94+
def onset_from_spikes(spikes_list, params):
95+
"""Calculate the cortical onset from a list of spikes, one for each trial.
96+
97+
Args:
98+
spikes_list: list of spikes as numpy arrays.
99+
params: dictionary of parameters from the trial steps configuration.
100+
101+
Returns:
102+
float representing the dynamic offset to be added to the initial offset of each trial step.
103+
"""
104+
L.info(
105+
"onset_from_spikes: processing %s arrays of spikes using params=%r",
106+
len(spikes_list),
107+
params,
108+
)
109+
spikes = np.concatenate(spikes_list)
87110
histogram = _histogram_from_spikes(spikes, params)
88111
onset_dict = _onset_from_histogram(histogram, params)
89-
return onset_dict
112+
return onset_dict["cortical_onset"]

src/blueetl/extract/trial_steps.py

Lines changed: 0 additions & 117 deletions
This file was deleted.

0 commit comments

Comments
 (0)