HART  0.2.0
High level Audio Regression and Testing
Loading...
Searching...
No Matches
hart_correlation_above.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <cmath> // round()
4#include <iomanip>
5#include <vector>
6#include <sstream>
7
10#include "hart_matcher.hpp"
12#include "hart_utils.hpp" // inf, floatsEqual()
13
14namespace hart
15{
16
17/// @brief Checks whether the output signal is sufficiently correlated with the input signal
18/// @details
19/// Uses normalized cross-correlation in the time domain to compare input and output audio,
20/// while searching for the best match within a configurable lag range.
21/// Correlation is calculated independently for every applicable channel using the formula:
22/// @f[
23/// \frac{\sum_n x[n]\,y[n+k]}
24/// {\sqrt{\left(\sum_n x[n]^2\right)\left(\sum_n y[n+k]^2\right)}}
25/// @f]
26///
27/// (`sum (x[n] * y[n+k]) / sqrt (sum (x[n]^2) * sum (y[n+k]^2))`),
28///
29/// where `x` is input signal and `y` is observed output signal.
30///
31/// For multi-channel audio, the lowest correlation value across all applicable channels is used.
32/// This matcher is useful for verifying transparent DSP, latency compensation, bypass paths,
33/// or processors that preserve waveform shape while introducing delay or mild coloration.
34/// Notes:
35/// - Gain differences do not affect the result due to normalization.
36/// - Constant DC offset reduces correlation.
37/// - Heavy nonlinear processing may significantly reduce correlation.
38/// - The absolute value of correlation is used, so polarity inversions do not affect the result.
39/// @tparam SampleType Floating point sample type, typically `float` or `double`
40/// @ingroup Matchers
41/// @deprecated Use `hart::crossCorrelation()` metric instead
42template <typename SampleType>
43class
44HART_DEPRECATED("Use hart::crossCorrelation() metric instead")
45CorrelationAbove :
46 public Matcher<SampleType, CorrelationAbove<SampleType>>
47{
48public:
49
50 /// @brief Creates a correlation matcher with a minimum accepted correlation threshold
51 /// @details The matcher scans lags in the range `[-maxLagSeconds, +maxLagSeconds]`
52 /// and finds the best normalized cross-correlation value.
53 /// A value of `1.0` requires a perfect waveform match (ignoring polarity and latency),
54 /// while lower values allow progressively more waveform deviation.
55 /// @param minCorrelation Minimum allowed absolute correlation in the range `[0, 1]`
56 /// @param maxLagSeconds Maximum absolute lag to search in seconds
57 CorrelationAbove (double minCorrelation, double maxLagSeconds = 0.01):
58 m_minCorrelation (minCorrelation),
59 m_maxLagSeconds (maxLagSeconds)
60 {
61 if (m_minCorrelation < 0 || m_minCorrelation > 1.0)
62 HART_THROW_OR_RETURN (hart::ValueError, "Correlation should be in 0..1 range", false);
63
64 if (m_maxLagSeconds < 0)
65 HART_THROW_OR_RETURN (hart::ValueError, "Max lag should be a non-negative number in seconds", false);
66 }
67
68 void prepare (double sampleRateHz, size_t /* numInputChannels */, size_t /* numOutputChannels */, size_t /*maxBlockSizeFrames*/) override
69 {
70 m_sampleRateHz = sampleRateHz;
71 m_maxLagFrames = static_cast<long long int> (std::round (m_maxLagSeconds * m_sampleRateHz));
72 }
73
74 bool canOperatePerBlock() const override
75 {
76 return false;
77 }
78
79 void reset() override
80 {
81 m_failureChannel = 0;
82 m_failureFrame = 0;
83 m_bestCorrelation = 0.0;
84 m_bestLagFrames = 0;
85 m_hadValidData = false;
86 }
87
88 bool supportsChannelLayout (size_t numInputChannels, size_t numOutputChannels) const override
89 {
90 return numInputChannels == numOutputChannels;
91 }
92
93 bool match (AnalysisContext<SampleType> context) override
94 {
95 const AudioBuffer<SampleType>& inputAudio = context.inputAudio();
96 const AudioBuffer<SampleType>& observedOutputAudio = context.outputAudio();
97
98 hassert (inputAudio.getNumChannels() == observedOutputAudio.getNumChannels());
99 hassert (inputAudio.getNumFrames() == observedOutputAudio.getNumFrames());
100 hassert (inputAudio.getSampleRateHz() == observedOutputAudio.getSampleRateHz());
101
102 const size_t numChannels = inputAudio.getNumChannels();
103 const size_t numFrames = inputAudio.getNumFrames();
104
105 if (numFrames == 0)
106 {
107 m_hadValidData = false;
108 return false;
109 }
110
111 double worstChannelCorrelation = hart::inf;
112 size_t worstChannel = 0;
113 bool anyValidChannel = false;
114
115 for (size_t channel = 0; channel < numChannels; ++channel)
116 {
117 if (! this->appliesToChannel (channel))
118 continue;
119
120 double bestCorrelation = -hart::inf;
121 long long int bestLag = 0;
122 bool channelValid = false;
123
124 const SampleType* x = inputAudio[channel];
125 const SampleType* y = observedOutputAudio[channel];
126 std::vector<double> prefixSumsSqX (numFrames + 1, 0.0);
127 std::vector<double> prefixSumsSqY (numFrames + 1, 0.0);
128 AccurateSum<double> runningSumSqX { 0.0 };
129 AccurateSum<double> runningSumSqY { 0.0 };
130
131 for (size_t frame = 0; frame < numFrames; ++frame)
132 {
133 const double xVal = static_cast<double> (x[frame]);
134 const double yVal = static_cast<double> (y[frame]);
135
136 runningSumSqX += xVal * xVal;
137 runningSumSqY += yVal * yVal;
138 prefixSumsSqX[frame + 1] = runningSumSqX;
139 prefixSumsSqY[frame + 1] = runningSumSqY;
140 }
141
142 // Formula:
143 // sum (x[n] * y[n+k]) / sqrt (sum (x[n]^2) * sum (y[n+k]^2))
144
145 for (long long int lag = -m_maxLagFrames; lag <= m_maxLagFrames; ++lag)
146 {
147 AccurateSum<double> dotProduct { 0.0 };
148 const bool lagShiftsOutputToTheLeft = lag < 0;
149 const size_t lagAbsFrames = static_cast<size_t> (lagShiftsOutputToTheLeft ? -lag : lag);
150
151 if (lagAbsFrames >= numFrames)
152 continue;
153
154 // For a given lag, correlate only the valid overlap interval:
155 // x[inputOverlapBeginFrame + offset] with y[outputOverlapBeginFrame + offset].
156 const size_t inputOverlapBeginFrame = lagShiftsOutputToTheLeft ? lagAbsFrames : 0;
157 const size_t outputOverlapBeginFrame = lagShiftsOutputToTheLeft ? 0 : lagAbsFrames;
158 const size_t overlapSizeFrames = numFrames - lagAbsFrames;
159 const size_t inputOverlapEndFrame = inputOverlapBeginFrame + overlapSizeFrames;
160 const size_t outputOverlapEndFrame = outputOverlapBeginFrame + overlapSizeFrames;
161 const double sumSqX = prefixSumsSqX[inputOverlapEndFrame] - prefixSumsSqX[inputOverlapBeginFrame];
162 const double sumSqY = prefixSumsSqY[outputOverlapEndFrame] - prefixSumsSqY[outputOverlapBeginFrame];
163
164 for (size_t overlapFrame = 0; overlapFrame < overlapSizeFrames; ++overlapFrame)
165 {
166 const double xnValue = static_cast<double> (x[inputOverlapBeginFrame + overlapFrame]);
167 const double ynValue = static_cast<double> (y[outputOverlapBeginFrame + overlapFrame]);
168 dotProduct += xnValue * ynValue;
169 }
170
171 if (floatsEqual (sumSqX, 0.0) || floatsEqual (sumSqY, 0.0))
172 continue;
173
174 channelValid = true;
175 const double corr = dotProduct / std::sqrt (sumSqX * sumSqY);
176 const double absCorr = std::abs (corr);
177
178 if (absCorr > bestCorrelation)
179 {
180 bestCorrelation = absCorr;
181 bestLag = lag;
182 }
183
184 if (floatsEqual (absCorr, 1.0))
185 {
186 bestCorrelation = absCorr;
187 bestLag = lag;
188 break;
189 }
190 }
191
192 if (! channelValid)
193 continue;
194
195 anyValidChannel = true;
196
197 if (bestCorrelation < worstChannelCorrelation)
198 {
199 worstChannelCorrelation = bestCorrelation;
200 worstChannel = channel;
201 m_bestCorrelation = bestCorrelation;
202 m_bestLagFrames = bestLag;
203 }
204 }
205
206 if (! anyValidChannel)
207 {
208 m_hadValidData = false;
209 m_failureChannel = 0;
210 m_failureFrame = 0;
211 return false;
212 }
213
214 m_hadValidData = true;
215
216 if (worstChannelCorrelation >= m_minCorrelation)
217 return true;
218
219 m_failureChannel = worstChannel;
220 m_failureFrame = 0; // no specific frame failure
221
222 return false;
223 }
224
225 MatcherFailureDetails getFailureDetails() const override
226 {
227 MatcherFailureDetails details;
228 details.channel = m_failureChannel;
229 details.frame = m_failureFrame;
230
231 if (! m_hadValidData)
232 {
233 details.description = "Correlation could not be computed (no valid signal overlap)";
234 return details;
235 }
236
237 const double lagSeconds = m_bestLagFrames / m_sampleRateHz;
238 std::stringstream stream;
239
240 stream
241 << "Best correlation: " << correlationPrecision << m_bestCorrelation
242 << " at lag " << m_bestLagFrames << " frames ("
243 << secPrecision << lagSeconds << " seconds)";
244
245 details.description = stream.str();
246 return details;
247 }
248
249 void represent (std::ostream& stream) const override
250 {
251 stream
252 << "CorrelationAbove ("
253 << correlationPrecision << m_minCorrelation << ", "
254 << secPrecision << m_maxLagSeconds << "_s)";
255 }
256
257private:
258 const double m_minCorrelation;
259 const double m_maxLagSeconds;
260
261 double m_sampleRateHz = 0.0;
262 long long int m_maxLagFrames = 0;
263
264 double m_bestCorrelation = 0.0;
265 long long m_bestLagFrames = 0;
266
267 size_t m_failureChannel = 0;
268 size_t m_failureFrame = 0;
269 bool m_hadValidData = false;
270};
271
272HART_MATCHER_DECLARE_ALIASES_FOR (CorrelationAbove)
273
274} // namespace hart
Implements Kahan algorithm for floating point accumulations.
AccurateSum(SampleType initialSum=(SampleType) 0)
Inits AccurateSum with a specific value.
AccurateSum & operator+=(SampleType value)
Adds a value to a sum, tracking the potential floating point error.
Contains audio-related artefacts useful for analysis by matchers.
Container for audio data.
Base for audio matchers.
Thrown when an inappropriate value is encountered.
#define hassert(condition)
Triggers a HartAssertException if the condition is false
#define HART_THROW_OR_RETURN(ExceptionType, message, returnValue)
Throws an exception unless HART_DO_NOT_THROW_EXCEPTIONS is set; otherwise prints a message and returns the specified value
std::ostream & secPrecision(std::ostream &stream)
Sets number of decimal places for values in seconds.
static std::ostream & correlationPrecision(std::ostream &stream)
Sets number of decimal places for correlation values.
#define HART_DEPRECATED(msg)
constexpr double inf
Infinity.
static SampleType floatsEqual(SampleType a, SampleType b, SampleType epsilon=(SampleType) 1e-8)
Compares two floating point numbers within a given tolerance.
#define HART_MATCHER_DECLARE_ALIASES_FOR(ClassName)
size_t channel
Index of channel at which the failure was detected.
std::string description
Readable description of why the match has failed.
size_t frame
Index of frame at which the match has failed.