42template <
typename SampleType>
46 public Matcher<SampleType, CorrelationAbove<SampleType>>
57 CorrelationAbove (
double minCorrelation,
double maxLagSeconds = 0.01):
58 m_minCorrelation (minCorrelation),
59 m_maxLagSeconds (maxLagSeconds)
61 if (m_minCorrelation < 0 || m_minCorrelation > 1.0)
64 if (m_maxLagSeconds < 0)
68 void prepare (
double sampleRateHz, size_t , size_t , size_t )
override
70 m_sampleRateHz = sampleRateHz;
71 m_maxLagFrames =
static_cast<
long long int> (std::round (m_maxLagSeconds * m_sampleRateHz));
74 bool canOperatePerBlock()
const override
83 m_bestCorrelation = 0.0;
85 m_hadValidData =
false;
88 bool supportsChannelLayout (size_t numInputChannels, size_t numOutputChannels)
const override
90 return numInputChannels == numOutputChannels;
95 const AudioBuffer<SampleType>& inputAudio = context.inputAudio();
96 const AudioBuffer<SampleType>& observedOutputAudio = context.outputAudio();
98 hassert (inputAudio.getNumChannels() == observedOutputAudio.getNumChannels());
99 hassert (inputAudio.getNumFrames() == observedOutputAudio.getNumFrames());
100 hassert (inputAudio.getSampleRateHz() == observedOutputAudio.getSampleRateHz());
102 const size_t numChannels = inputAudio.getNumChannels();
103 const size_t numFrames = inputAudio.getNumFrames();
107 m_hadValidData =
false;
111 double worstChannelCorrelation =
hart::inf;
112 size_t worstChannel = 0;
113 bool anyValidChannel =
false;
115 for (size_t channel = 0; channel < numChannels; ++channel)
117 if (!
this->appliesToChannel (channel))
121 long long int bestLag = 0;
122 bool channelValid =
false;
124 const SampleType* x = inputAudio[channel];
125 const SampleType* y = observedOutputAudio[channel];
126 std::vector<
double> prefixSumsSqX (numFrames + 1, 0.0);
127 std::vector<
double> prefixSumsSqY (numFrames + 1, 0.0);
131 for (size_t frame = 0; frame < numFrames; ++frame)
133 const double xVal =
static_cast<
double> (x[frame]);
134 const double yVal =
static_cast<
double> (y[frame]);
136 runningSumSqX
+= xVal * xVal;
137 runningSumSqY
+= yVal * yVal;
138 prefixSumsSqX[frame + 1] = runningSumSqX;
139 prefixSumsSqY[frame + 1] = runningSumSqY;
145 for (
long long int lag = -m_maxLagFrames; lag <= m_maxLagFrames; ++lag)
148 const bool lagShiftsOutputToTheLeft = lag < 0;
149 const size_t lagAbsFrames =
static_cast<size_t> (lagShiftsOutputToTheLeft ? -lag : lag);
151 if (lagAbsFrames >= numFrames)
156 const size_t inputOverlapBeginFrame = lagShiftsOutputToTheLeft ? lagAbsFrames : 0;
157 const size_t outputOverlapBeginFrame = lagShiftsOutputToTheLeft ? 0 : lagAbsFrames;
158 const size_t overlapSizeFrames = numFrames - lagAbsFrames;
159 const size_t inputOverlapEndFrame = inputOverlapBeginFrame + overlapSizeFrames;
160 const size_t outputOverlapEndFrame = outputOverlapBeginFrame + overlapSizeFrames;
161 const double sumSqX = prefixSumsSqX[inputOverlapEndFrame] - prefixSumsSqX[inputOverlapBeginFrame];
162 const double sumSqY = prefixSumsSqY[outputOverlapEndFrame] - prefixSumsSqY[outputOverlapBeginFrame];
164 for (size_t overlapFrame = 0; overlapFrame < overlapSizeFrames; ++overlapFrame)
166 const double xnValue =
static_cast<
double> (x[inputOverlapBeginFrame + overlapFrame]);
167 const double ynValue =
static_cast<
double> (y[outputOverlapBeginFrame + overlapFrame]);
168 dotProduct
+= xnValue * ynValue;
175 const double corr = dotProduct / std::sqrt (sumSqX * sumSqY);
176 const double absCorr = std::abs (corr);
178 if (absCorr > bestCorrelation)
180 bestCorrelation = absCorr;
186 bestCorrelation = absCorr;
195 anyValidChannel =
true;
197 if (bestCorrelation < worstChannelCorrelation)
199 worstChannelCorrelation = bestCorrelation;
200 worstChannel = channel;
201 m_bestCorrelation = bestCorrelation;
202 m_bestLagFrames = bestLag;
206 if (! anyValidChannel)
208 m_hadValidData =
false;
209 m_failureChannel = 0;
214 m_hadValidData =
true;
216 if (worstChannelCorrelation >= m_minCorrelation)
219 m_failureChannel = worstChannel;
229 details
.frame = m_failureFrame;
231 if (! m_hadValidData)
233 details
.description =
"Correlation could not be computed (no valid signal overlap)";
237 const double lagSeconds = m_bestLagFrames / m_sampleRateHz;
238 std::stringstream stream;
242 <<
" at lag " << m_bestLagFrames <<
" frames ("
249 void represent (std::ostream& stream)
const override
252 <<
"CorrelationAbove ("
258 const double m_minCorrelation;
259 const double m_maxLagSeconds;
261 double m_sampleRateHz = 0.0;
262 long long int m_maxLagFrames = 0;
264 double m_bestCorrelation = 0.0;
265 long long m_bestLagFrames = 0;
267 size_t m_failureChannel = 0;
268 size_t m_failureFrame = 0;
269 bool m_hadValidData =
false;
Implements the Kahan summation algorithm for floating-point accumulation.
AccurateSum(SampleType initialSum=(SampleType) 0)
Initializes AccurateSum with a specific starting value.
AccurateSum & operator+=(SampleType value)
Adds a value to the sum while tracking the accumulated floating-point rounding error.
Contains audio-related artefacts useful for analysis by matchers.
Container for audio data.
Thrown when an inappropriate value is encountered.
#define hassert(condition)
Triggers a HartAssertException if the condition is false.
#define HART_THROW_OR_RETURN(ExceptionType, message, returnValue)
Throws an exception unless HART_DO_NOT_THROW_EXCEPTIONS is set; when it is set, prints the message and returns the specified value instead.
std::ostream & secPrecision(std::ostream &stream)
Sets the number of decimal places used when printing values in seconds.
static std::ostream & correlationPrecision(std::ostream &stream)
Sets the number of decimal places used when printing correlation values.
#define HART_DEPRECATED(msg)
constexpr double inf
Infinity.
static SampleType floatsEqual(SampleType a, SampleType b, SampleType epsilon=(SampleType) 1e-8)
Compares two floating point numbers within a given tolerance.
#define HART_MATCHER_DECLARE_ALIASES_FOR(ClassName)
Details about a matcher failure.
size_t channel
Index of the channel at which the failure was detected.
std::string description
Human-readable description of why the match failed.
size_t frame
Index of the frame at which the match failed.