-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbibliotheque_artifact_detection.py
267 lines (218 loc) · 8.04 KB
/
bibliotheque_artifact_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# TOOLS FOR ARTIFACT DETECTION AND REPLACEMENT
import numpy as np
import pandas as pd
from scipy import interpolate
from scipy import signal
def detect_cross(sig, threshold):
    """
    Detect threshold crossings and pair each rise with the following decay
    ----------
    Parameters
    ----------
    - sig : np.array
        A 1D signal
    - threshold : float
        Amplitude threshold where crossings are detected
    -------
    Returns
    -------
    - pd.DataFrame or None
        Dataframe with one row per complete excursion above threshold:
        'rises' = index of the last sample not above threshold before the signal goes above it,
        'decays' = index of the last sample above threshold before it drops back to or below it.
        None if no complete rise/decay pair is detected.
    """
    # A single "above" mask guarantees that rises and decays strictly alternate,
    # even when samples are exactly equal to the threshold.
    above = np.asarray(sig) > threshold
    rises, = np.where(~above[:-1] & above[1:])   # not above -> above
    decays, = np.where(above[:-1] & ~above[1:])  # above -> not above
    # Signal starts above threshold: the first decay has no matching rise
    if rises.size and decays.size and decays[0] < rises[0]:
        decays = decays[1:]
    # Signal ends above threshold: the last rise has no matching decay
    if rises.size and (decays.size == 0 or rises[-1] > decays[-1]):
        rises = rises[:-1]
    if rises.size == 0:
        return None
    return pd.DataFrame({'rises': rises, 'decays': decays})
def compute_rms(x):
    """
    Root mean square of a signal
    ----------
    Parameters
    ----------
    - x : np.array
        A non-empty signal
    -------
    Returns
    -------
    - float
        Scalar root mean square of all samples of x
    -------
    Raises
    -------
    - ValueError
        If x is empty
    """
    x = np.asarray(x)
    if x.size == 0:
        raise ValueError("compute_rms() needs at least one sample")
    # Integer samples (e.g. int16 recordings) would overflow when squared in
    # their own dtype, so promote them to float first.
    if np.issubdtype(x.dtype, np.integer):
        x = x.astype(np.float64)
    return np.sqrt(np.mean(np.square(x)))
def sliding_rms(x, sf, window=0.5, step=0.2, interp=True):
    """
    Compute a sliding root mean square
    ----------
    Parameters
    ----------
    - x : np.array
        Signal
    - sf : float
        Sampling frequency
    - window : float
        Duration of the sliding window in seconds
    - step : float
        Time duration between steps (seconds)
    - interp : bool
        Interpolate to return a vector of same size as the input signal
    -------
    Returns
    -------
    - t : np.array
        Time vector of the returned trace (window centres if not interpolated,
        one time point per input sample otherwise)
    - out : np.array
        Signal (trace) smoothed by RMS
    """
    halfdur = window / 2
    n = x.size
    # Nominal window centres in seconds, one every `step`
    centres = np.arange(0, n / sf, step)
    beg = ((centres - halfdur) * sf).astype(int)
    end = ((centres + halfdur) * sf).astype(int)
    # Truncate windows that overhang the signal. `end` is an exclusive slice
    # bound, so it is clipped to n (clipping to n - 1 would drop the last sample).
    # To discard incomplete windows instead, replace the two lines below with:
    #   mask = (beg >= 0) & (end <= n); beg, end = beg[mask], end[mask]
    beg[beg < 0] = 0
    end[end > n] = n
    # Time stamp of each RMS value = centre of its (possibly truncated) window
    t = (beg + end) / 2 / sf
    out = np.zeros(centres.size)
    for i, (b, e) in enumerate(zip(beg, end)):
        out[i] = compute_rms(x[b:e])
    # Resample onto the input time base, unless there already is one value per sample
    if interp and not np.isclose(step * sf, 1):
        # Cubic splines need at least 4 points; fall back to linear for very short signals
        kind = "cubic" if t.size >= 4 else "linear"
        # fill_value=0 sets the trace to 0 before the first and after the last
        # window centre. Presumably intentional: it lets a crossing detector
        # (detect_cross) see a rise/decay for artifacts at the very edges of the
        # recording — confirm before changing.
        f = interpolate.interp1d(t, out, kind=kind, bounds_error=False, fill_value=0, assume_sorted=True)
        t = np.arange(n) / sf
        out = f(t)
    return t, out
def compute_artifact_features(inds, srate):
    """
    Turn rise/decay index pairs into a table of artifact timing features
    ----------
    Parameters
    ----------
    - inds : pd.DataFrame
        Frame with 'rises' and 'decays' index columns (output of detect_cross())
    - srate : float
        Sampling rate used to convert indices to seconds
    -------
    Returns
    -------
    - pd.DataFrame
        One row per rise/decay pair (same index as inds) with columns
        start_ind, stop_ind (int), start_t, stop_t, duration (seconds)
    """
    start_idx = inds['rises'].astype(int)
    stop_idx = inds['decays'].astype(int)
    start_sec = start_idx / srate
    stop_sec = stop_idx / srate
    features = {
        'start_ind': start_idx,
        'stop_ind': stop_idx,
        'start_t': start_sec,
        'stop_t': stop_sec,
        'duration': stop_sec - start_sec,
    }
    return pd.DataFrame(features)
def detect_artifacts(sig, srate, n_deviations=5, low_freq=40, high_freq=150, wsize=1, step=0.2):
    """
    Detect artifacts based on bursts of low_freq to high_freq power, deduced from filtering + sliding RMS of the filtered signal
    ----------
    Parameters
    ----------
    - sig : np.array
        Raw signal
    - srate : float
        Sampling rate
    - n_deviations : float
        Number of MAD deviations from the median of the RMS filtered signal
    - low_freq : float
        Low cutoff frequency of the IIR filter
    - high_freq : float
        High cutoff frequency of the IIR filter
    - wsize : float
        Window size (duration) in seconds of the RMS window
    - step : float
        Time duration (seconds) between steps of the RMS window
    -------
    Returns
    -------
    - pd.DataFrame or None
        pd.DataFrame of temporal features of artifacted windows (see compute_artifact_features) or None if no artifact detected
    """
    # Project-local helper package; imported lazily so the rest of the module
    # can be used without it.
    import ghibtools as gh
    sig_filtered = gh.iirfilt(sig, srate, low_freq, high_freq, ftype='bessel', order=2)
    _, sig_cross = sliding_rms(sig_filtered, sf=srate, window=wsize, step=step)
    # Robust threshold: median + n_deviations * dispersion (gh.mad, per the parameter docs)
    detect_threshold = np.median(sig_cross) + n_deviations * gh.mad(sig_cross)
    times = detect_cross(sig_cross, detect_threshold)
    if times is None:
        return None
    # compute_artifact_features takes (inds, srate); passing sig_cross as well
    # raised TypeError whenever an artifact was found.
    return compute_artifact_features(times, srate)
def _coloured_noise(sig, srate, noise_size, freq_min, seed):
    """Return `noise_size` samples of random-phase noise with the (median Welch) amplitude spectrum of `sig`, high-passed at `freq_min` and scaled to the median power of the high-passed `sig`."""
    _, spectrum = signal.welch(sig, nperseg=noise_size, nfft=noise_size, noverlap=0, scaling='spectrum', window='box', return_onesided=False, average='median')
    amplitude = np.sqrt(spectrum)
    # Random phases taken from seeded white noise, magnitudes from the signal spectrum
    rng = np.random.RandomState(seed=seed)
    phases = np.angle(np.fft.fft(rng.randn(noise_size)))
    # Two-sided spectrum with Hermitian-symmetric phases -> real up to rounding; keep the real part explicitly
    noise = np.real(np.fft.ifft(amplitude * np.exp(1j * phases))).astype(sig.dtype)
    sos = signal.iirfilter(2, freq_min / (srate / 2), analog=False, btype='highpass', ftype='bessel', output='sos')
    noise = signal.sosfiltfilt(sos, noise, axis=0)
    filtered_sig = signal.sosfiltfilt(sos, sig, axis=0)
    # Match the median power of the high-passed noise to that of the high-passed signal
    factor = np.sqrt(np.median(filtered_sig ** 2)) / np.sqrt(np.median(noise ** 2))
    return noise * factor


def insert_noise(sig, srate, chan_artifacts, freq_min=30., margin_s=0.2, seed=None):
    """
    Insert patches of white noise coloured with the average power content of a signal where artifacts have been detected in it
    ----------
    Parameters
    ----------
    - sig : np.array
        Raw 1D signal
    - srate : float
        Sampling rate
    - chan_artifacts : pd.DataFrame or None
        Output of detect_artifacts(); only its 'start_ind' / 'stop_ind' columns are used.
        None or an empty frame means there is nothing to patch.
    - freq_min : float
        Cutoff frequency of the high-pass filter applied to both signal and noise before matching their power
    - margin_s : float
        Duration in seconds of the margins on each side of a patch where signal and noise are cross-faded
        (truncated when a patch is closer than that to an edge of the signal)
    - seed : int or None
        Seed to generate random noise
    -------
    Returns
    -------
    - sig_corrected : np.array
        Copy of sig with patches of coloured noise (unchanged copy if there is no artifact)
    """
    sig_corrected = sig.copy()
    if chan_artifacts is None or chan_artifacts.shape[0] == 0:
        return sig_corrected
    n = len(sig)
    margin = int(srate * margin_s)
    up = np.linspace(0, 1, margin)
    down = np.linspace(1, 0, margin)
    # Read indices column-wise as ints: iterrows() would upcast them to float
    # when the frame also holds float columns (start_t, ...), breaking slicing.
    starts = chan_artifacts['start_ind'].to_numpy().astype(int)
    stops = chan_artifacts['stop_ind'].to_numpy().astype(int)
    # One long noise piece covering every patch with full margins (an upper bound)
    noise_size = int(np.sum(stops - starts)) + 2 * margin * starts.size
    long_noise = _coloured_noise(sig, srate, noise_size, freq_min, seed)
    noise_ind = 0
    for ind0, ind1 in zip(starts, stops):
        # Patch span including margins, clipped to the signal
        lo = max(ind0 - margin, 0)
        hi = min(ind1 + margin, n)
        left = ind0 - lo   # actual left margin length
        right = hi - ind1  # actual right margin length
        n_samples = hi - lo
        # Remove the artifact and fade the signal out/in over the margins
        sig_corrected[ind0:ind1] = 0
        sig_corrected[lo:ind0] *= down[margin - left:]
        sig_corrected[ind1:hi] *= up[:right]
        noise = long_noise[noise_ind:noise_ind + n_samples].copy()
        noise_ind += n_samples
        # Add a linear baseline joining the signal just outside the patch
        # (anchors clamped so they never wrap around to the other end)
        left_anchor = sig[max(lo - 1, 0)]
        right_anchor = sig[min(hi + 1, n - 1)]
        noise += np.linspace(left_anchor, right_anchor, n_samples)
        # Fade the noise in/out over the margins (complementary to the signal fades)
        noise[:left] *= up[margin - left:]
        noise[n_samples - right:] *= down[:right]
        sig_corrected[lo:hi] += noise
    return sig_corrected