Skip to content

Commit

Permalink
Merge pull request #134 from rest-for-physics/jporron-work-on-PeakFinder
Browse files Browse the repository at this point in the history
Change peak finding algorithm
  • Loading branch information
lobis authored May 30, 2024
2 parents 1137fde + a9c083c commit 22760ed
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 25 deletions.
6 changes: 5 additions & 1 deletion inc/TRestRawPeaksFinderProcess.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class TRestRawPeaksFinderProcess : public TRestEventProcess {
UShort_t fDistance = 10;
/// \brief window size to calculate the peak amplitude
UShort_t fWindow = 10;
/// \brief option to remove all veto signals after finding the peaks
Bool_t fRemoveAllVetos = false;
/// \brief option to remove peakless veto signals after finding the peaks
Bool_t fRemovePeaklessVetos = false;

std::set<std::string> fChannelTypes = {}; // this process will only be applied to selected channel types

Expand All @@ -45,7 +49,7 @@ class TRestRawPeaksFinderProcess : public TRestEventProcess {
TRestRawPeaksFinderProcess() = default;
~TRestRawPeaksFinderProcess() = default;

ClassDefOverride(TRestRawPeaksFinderProcess, 3);
ClassDefOverride(TRestRawPeaksFinderProcess, 4);
};

#endif // REST_TRESTRAWPEAKSFINDERPROCESS_H
1 change: 1 addition & 0 deletions inc/TRestRawSignal.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class TRestRawSignal {
/// Peaks are defined as the points that are above the threshold and are separated by a minimum distance
/// in time bin units. The threshold must be set in absolute value (regardless of the baseline)
std::vector<std::pair<UShort_t, double>> GetPeaks(double threshold, UShort_t distance = 5) const;
std::vector<std::pair<UShort_t, double>> GetPeaksVeto(double threshold, UShort_t distance = 5) const;

TRestRawSignal();
TRestRawSignal(Int_t nBins);
Expand Down
99 changes: 83 additions & 16 deletions src/TRestRawPeaksFinderProcess.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ TRestEvent* TRestRawPeaksFinderProcess::ProcessEvent(TRestEvent* inputEvent) {
fInputEvent->InitializeReferences(run);
}

auto event = fInputEvent->GetSignalEventForTypes(fChannelTypes, fReadoutMetadata);

if (fReadoutMetadata == nullptr) {
fReadoutMetadata = fInputEvent->GetReadoutMetadata();
}
Expand All @@ -32,8 +30,8 @@ TRestEvent* TRestRawPeaksFinderProcess::ProcessEvent(TRestEvent* inputEvent) {

vector<tuple<UShort_t, UShort_t, double>> eventPeaks;

for (int signalIndex = 0; signalIndex < event.GetNumberOfSignals(); signalIndex++) {
const auto signal = event.GetSignal(signalIndex);
for (int signalIndex = 0; signalIndex < fInputEvent->GetNumberOfSignals(); signalIndex++) {
const auto signal = fInputEvent->GetSignal(signalIndex);
const UShort_t signalId = signal->GetSignalID();

const string channelType = fReadoutMetadata->GetTypeForChannelDaqId(signalId);
Expand All @@ -44,18 +42,27 @@ TRestEvent* TRestRawPeaksFinderProcess::ProcessEvent(TRestEvent* inputEvent) {
continue;
}

signal->CalculateBaseLine(0, 5);
const auto peaks = signal->GetPeaks(signal->GetBaseLine() + 1.0, fDistance);
// Choose appropriate function based on channel type
if (channelType == "tpc") {
signal->CalculateBaseLine(fBaselineRange.X(), fBaselineRange.Y());
const auto peaks =
signal->GetPeaks(signal->GetBaseLine() + 5 * signal->GetBaseLineSigma(), fDistance);

for (const auto& [time, amplitude] : peaks) {
eventPeaks.emplace_back(signalId, time, amplitude);
}
/*
cout << "Signal ID: " << channelId << " Name: " << channelName << endl;
for (const auto& [time, amplitude] : peaks) {
cout << " - Peak at " << time << " with amplitude " << amplitude << endl;
for (const auto& [time, amplitude] : peaks) {
eventPeaks.emplace_back(signalId, time, amplitude);
}
} else if (channelType == "veto") {
// For veto signals the baseline is calculated over the whole range, as we don´t know where the
// signal will be.
signal->CalculateBaseLine(0, 511, "robust");
// For veto signals the threshold is selected by the user.
const auto peaks =
signal->GetPeaksVeto(signal->GetBaseLine() + fThresholdOverBaseline, fDistance);

for (const auto& [time, amplitude] : peaks) {
eventPeaks.emplace_back(signalId, time, amplitude);
}
}
*/
}

// sort eventPeaks by time, then signal id
Expand All @@ -68,16 +75,23 @@ TRestEvent* TRestRawPeaksFinderProcess::ProcessEvent(TRestEvent* inputEvent) {
vector<UShort_t> peaksChannelId;
vector<UShort_t> peaksTime;
vector<double> peaksAmplitude;
double peaksEnergy;
double amplitudeTotal = 0.0;

for (const auto& [channelId, time, amplitude] : eventPeaks) {
peaksChannelId.push_back(channelId);
peaksTime.push_back(time);
peaksAmplitude.push_back(amplitude);

amplitudeTotal += amplitude;
}

peaksEnergy = amplitudeTotal;

SetObservableValue("peaksChannelId", peaksChannelId);
SetObservableValue("peaksTime", peaksTime);
SetObservableValue("peaksAmplitude", peaksAmplitude);
SetObservableValue("totalPeaksEnergy", peaksEnergy);

vector<UShort_t> windowIndex(eventPeaks.size(), 0); // Initialize with zeros
vector<UShort_t> windowCenter; // for each different window, the center of the window
Expand Down Expand Up @@ -140,6 +154,49 @@ TRestEvent* TRestRawPeaksFinderProcess::ProcessEvent(TRestEvent* inputEvent) {
SetObservableValue("windowIndex", windowIndex);
SetObservableValue("windowCenter", windowCenter);

// Remove peakless veto signals after the peak finding if chosen
if (fRemovePeaklessVetos && !fRemoveAllVetos) {
set<UShort_t> peakSignalIds;
for (const auto& [channelId, time, amplitude] : eventPeaks) {
peakSignalIds.insert(channelId);
}

vector<UShort_t> signalsToRemove;
for (int signalIndex = 0; signalIndex < fInputEvent->GetNumberOfSignals(); signalIndex++) {
const auto signal = fInputEvent->GetSignal(signalIndex);
const UShort_t signalId = signal->GetSignalID();
const string signalType = fReadoutMetadata->GetTypeForChannelDaqId(signalId);

if (signalType == "veto" && peakSignalIds.find(signalId) == peakSignalIds.end()) {
signalsToRemove.push_back(signalId);
}
}

// Now remove all veto signals identified
for (const auto& signalId : signalsToRemove) {
fInputEvent->RemoveSignalWithId(signalId);
}
}

// Remove all veto signals after the peak finding if chosen
if (fRemoveAllVetos) {
vector<UShort_t> signalsToRemove;
for (int signalIndex = 0; signalIndex < fInputEvent->GetNumberOfSignals(); signalIndex++) {
const auto signal = fInputEvent->GetSignal(signalIndex);
const UShort_t signalId = signal->GetSignalID();
const string signalType = fReadoutMetadata->GetTypeForChannelDaqId(signalId);

if (signalType == "veto") {
signalsToRemove.push_back(signalId);
}
}

// Now remove all veto signals identified
for (const auto& signalId : signalsToRemove) {
fInputEvent->RemoveSignalWithId(signalId);
}
}

return inputEvent;
}

Expand All @@ -157,6 +214,8 @@ void TRestRawPeaksFinderProcess::InitFromConfigFile() {
fBaselineRange = Get2DVectorParameterWithUnits("baselineRange", fBaselineRange);
fDistance = StringToDouble(GetParameter("distance", fDistance));
fWindow = StringToDouble(GetParameter("window", fWindow));
fRemoveAllVetos = StringToBool(GetParameter("removeAllVetos", fRemoveAllVetos));
fRemovePeaklessVetos = StringToBool(GetParameter("removePeaklessVetos", fRemovePeaklessVetos));

if (fBaselineRange.X() > fBaselineRange.Y() || fBaselineRange.X() < 0 || fBaselineRange.Y() < 0) {
cerr << "TRestRawPeaksFinderProcess::InitProcess: baseline range is not sorted or < 0" << endl;
Expand All @@ -177,6 +236,13 @@ void TRestRawPeaksFinderProcess::InitFromConfigFile() {
cerr << "TRestRawPeaksFinderProcess::InitProcess: window is < 0" << endl;
exit(1);
}

if (filterType != "veto" && fRemovePeaklessVetos) {
cerr << "TRestRawPeaksFinderProcess::InitProcess: removing veto signals only makes sense when the "
"process is applied to veto signals. Remove \"removePeaklessVetos\" parameter"
<< endl;
exit(1);
}
}

void TRestRawPeaksFinderProcess::PrintMetadata() {
Expand All @@ -188,8 +254,9 @@ void TRestRawPeaksFinderProcess::PrintMetadata() {
}
RESTMetadata << RESTendl;

RESTMetadata << "Threshold over baseline: " << fThresholdOverBaseline << RESTendl;
RESTMetadata << "Baseline range: " << fBaselineRange.X() << " - " << fBaselineRange.Y() << RESTendl;
RESTMetadata << "Baseline range for tpc signals: " << fBaselineRange.X() << " - " << fBaselineRange.Y()
<< RESTendl;
RESTMetadata << "Threshold over baseline for veto signals: " << fThresholdOverBaseline << RESTendl;

RESTMetadata << "Distance: " << fDistance << RESTendl;
RESTMetadata << "Window: " << fWindow << RESTendl;
Expand Down
181 changes: 173 additions & 8 deletions src/TRestRawSignal.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

#include <TAxis.h>
#include <TF1.h>
#include <TH1D.h>
#include <TMath.h>
#include <TRandom3.h>

Expand Down Expand Up @@ -912,19 +913,183 @@ TGraph* TRestRawSignal::GetGraph(Int_t color) {
vector<pair<UShort_t, double>> TRestRawSignal::GetPeaks(double threshold, UShort_t distance) const {
vector<pair<UShort_t, double>> peaks;

for (UShort_t i = 0; i < GetNumberOfPoints(); i++) {
const double point = GetRawData(i);
if (i > 0 && i < GetNumberOfPoints() - 1) {
double prevPoint = GetRawData(i - 1);
double nextPoint = GetRawData(i + 1);
const UShort_t smoothingWindow =
10; // Region to compare for peak/no peak classification. 10 means 5 bins to each side
const size_t numPoints = GetNumberOfPoints();

if (point > threshold && point >= prevPoint && point >= nextPoint) {
// Check if the peak is spaced far enough from the previous peak
if (numPoints == 0) return peaks;

// Pre-calculate smoothed values for all bins using a rolling sum
vector<double> smoothedValues(numPoints, 0.0);
double currentSum = 0.0;
UShort_t windowSize = smoothingWindow + 1;

// Initialize the sum for the first window
for (UShort_t i = 0; i < static_cast<UShort_t>(std::min<size_t>(windowSize, numPoints)); ++i) {
currentSum += GetRawData(i);
}
smoothedValues[0] = currentSum / windowSize;

for (UShort_t i = 1; i < numPoints; ++i) {
if (i < smoothingWindow / 2 + 1) {
// Adjust the window size at the beginning
currentSum = 0.0;
UShort_t currentWindowSize =
static_cast<UShort_t>(std::min<size_t>(windowSize, i + smoothingWindow / 2 + 1));
for (UShort_t j = 0; j < currentWindowSize; ++j) {
currentSum += GetRawData(j);
}
smoothedValues[i] = currentSum / currentWindowSize;
} else if (i > numPoints - smoothingWindow / 2 - 1) {
// Adjust the window size at the end
currentSum = 0.0;
UShort_t currentWindowSize =
static_cast<UShort_t>(std::min<size_t>(windowSize, numPoints - i + smoothingWindow / 2));
for (UShort_t j = i - smoothingWindow / 2; j < numPoints; ++j) {
currentSum += GetRawData(j);
}
smoothedValues[i] = currentSum / currentWindowSize;
} else {
// Use the rolling sum for the middle bins
currentSum -= GetRawData(i - smoothingWindow / 2 - 1);
currentSum += GetRawData(i + smoothingWindow / 2);
smoothedValues[i] = currentSum / windowSize;
}
}

// Compare pre-calculated smoothed values to identify peaks
for (UShort_t i = 0; i < numPoints; ++i) {
const double smoothedValue = smoothedValues[i];

if (i >= smoothingWindow / 2 && i < numPoints - smoothingWindow / 2) {
bool isPeak = true;
int numGreaterEqual = 0; // Counter for smoothed values greater or equal to the studied bin

for (UShort_t j = i - smoothingWindow / 2; j <= i + smoothingWindow / 2; ++j) {
if (j != i && smoothedValue <= smoothedValues[j]) {
numGreaterEqual++;
if (numGreaterEqual >
2) { // If more than one smoothed value is greater or equal, it's not a peak
isPeak = false;
break;
}
}
}

// If it's a peak and it´s above the threshold and further than distance to the previous peak, add
// to peaks
if (isPeak && smoothedValue > threshold) {
if (peaks.empty() || i - peaks.back().first >= distance) {
peaks.push_back(std::make_pair(i, point));
double fitMinRange = i - 20;
double fitMaxRange = i + 20;

// Create a Gaussian fit function
TF1 fitFunction("gaussianFit", "gaus", fitMinRange, fitMaxRange);
// Fit the data with the Gaussian function
fitFunction.SetRange(fitMinRange, fitMaxRange); // Initial parameters

// Create histogram with the values to fit
TH1D histogram("hist", "hist", 40, fitMinRange, fitMaxRange);
for (int k = i - 20; k <= i + 20; ++k) {
histogram.SetBinContent(k - (i - 20) + 1, GetRawData(k)); // Set bin content
}
histogram.Fit(&fitFunction, "RQ");

// Get peak position and amplitude from the fit
double peakPosition = fitFunction.GetParameter(1);
UShort_t formattedPeakPosition = static_cast<UShort_t>(peakPosition);
double peakAmplitude = GetRawData(formattedPeakPosition);

peaks.push_back(std::make_pair(formattedPeakPosition, peakAmplitude));
}
}
}
}

return peaks;
}

vector<pair<UShort_t, double>> TRestRawSignal::GetPeaksVeto(double threshold, UShort_t distance) const {
vector<pair<UShort_t, double>> peaks;

const UShort_t smoothingWindow =
4; // Region to compare for peak/no peak classification. 10 means 5 bins to each side
const size_t numPoints = GetNumberOfPoints();

if (numPoints == 0) {
return peaks;
}

// Pre-calculate smoothed values for all bins using a rolling sum
vector<double> smoothedValues(numPoints, 0.0);
double currentSum = 0.0;
UShort_t windowSize = smoothingWindow + 1;

// Initialize the sum for the first window
for (UShort_t i = 0; i < static_cast<UShort_t>(std::min<size_t>(windowSize, numPoints)); ++i) {
currentSum += GetRawData(i);
}
smoothedValues[0] = currentSum / windowSize;

for (UShort_t i = 1; i < numPoints; ++i) {
if (i < smoothingWindow / 2 + 1) {
// Adjust the window size at the beginning
currentSum = 0.0;
UShort_t currentWindowSize =
static_cast<UShort_t>(std::min<size_t>(windowSize, i + smoothingWindow / 2 + 1));
for (UShort_t j = 0; j < currentWindowSize; ++j) {
currentSum += GetRawData(j);
}
smoothedValues[i] = currentSum / currentWindowSize;
} else if (i > numPoints - smoothingWindow / 2 - 1) {
// Adjust the window size at the end
currentSum = 0.0;
UShort_t currentWindowSize =
static_cast<UShort_t>(std::min<size_t>(windowSize, numPoints - i + smoothingWindow / 2));
for (UShort_t j = i - smoothingWindow / 2; j < numPoints; ++j) {
currentSum += GetRawData(j);
}
smoothedValues[i] = currentSum / currentWindowSize;
} else {
// Use the rolling sum for the middle bins
currentSum -= GetRawData(i - smoothingWindow / 2 - 1);
currentSum += GetRawData(i + smoothingWindow / 2);
smoothedValues[i] = currentSum / windowSize;
}
}

// Compare pre-calculated smoothed values to identify peaks
for (size_t i = 0; i < numPoints; ++i) {
const double smoothedValue = smoothedValues[i];

if (i >= smoothingWindow / 2 && i < numPoints - smoothingWindow / 2) {
bool isPeak = true;
int numGreaterEqual = 0; // Counter for smoothed values greater or equal to the studied bin

for (size_t j = i - smoothingWindow / 2; j <= i + smoothingWindow / 2; ++j) {
if (j != i && smoothedValue <= smoothedValues[j]) {
numGreaterEqual++;
if (numGreaterEqual >
0) { // If more than one smoothed value is greater or equal, it's not a peak
isPeak = false;
break;
}
}
}

// If it's a peak and it´s above the threshold and further than distance to the previous peak, add
// to peaks
if (isPeak && smoothedValue > threshold) {
if (peaks.empty() || i - peaks.back().first >= distance) {
auto peakPosition = double(i);
auto formattedPeakPosition = static_cast<UShort_t>(peakPosition);
double peakAmplitude = GetRawData(formattedPeakPosition);

peaks.emplace_back(formattedPeakPosition, peakAmplitude);
}
}
}
}

return peaks;
}

0 comments on commit 22760ed

Please sign in to comment.