HART  0.2.0
High level Audio Regression and Testing
Loading...
Searching...
No Matches
hart_polarity_preserved.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <cmath>
4#include <iomanip>
5#include <sstream>
6#include <vector>
7
10#include "hart_matcher.hpp"
13#include "hart_utils.hpp"
14
15namespace hart
16{
17
18/// @brief Checks whether the output signal preserves the polarity of the input signal
19/// @details
20/// Uses normalized cross-correlation in the time domain to compare input and output audio,
21/// while searching for the best match within a configurable lag range.
22/// Correlation is calculated independently for every applicable channel using the formula:
23/// @f[
24/// \frac{\sum_n x[n]\,y[n+k]}
25/// {\sqrt{\left(\sum_n x[n]^2\right)\left(\sum_n y[n+k]^2\right)}}
26/// @f]
27///
28/// (`sum (x[n] * y[n+k]) / sqrt (sum (x[n]^2) * sum (y[n+k]^2))`),
29///
30/// where `x` is the input signal, `y` is the observed output signal and `k` is the lag in frames.
/// All sums (including the energies in the denominator) run only over the frames where
/// both signals overlap at lag `k`.
31///
32/// The lag with the highest absolute correlation is used to compensate for latency,
33/// and the signed correlation at this lag is then checked against a minimum signed
34/// correlation threshold.
35///
36/// For multi-channel audio, all applicable channels must preserve polarity.
37/// If the signed correlation of at least one applicable channel is below the threshold, the match fails.
38///
39/// This matcher is useful for detecting accidental polarity inversions while remaining
40/// robust to latency and gain differences.
41///
42/// Notes:
43/// - Gain differences do not affect the result due to normalization.
44/// - Constant DC offset may bias the signed correlation (signals are not mean-removed).
45/// - Heavy nonlinear processing may reduce confidence in polarity detection.
46/// - Unlike @ref hart::CorrelationAbove, the sign of correlation is preserved.
47/// @tparam SampleType Floating point sample type, typically `float` or `double`
48/// @ingroup Matchers
49template <typename SampleType>
51 public Matcher<SampleType, PolarityPreserved<SampleType>>
52{
53public:
54
55 /// @brief Creates a polarity matcher with a minimum signed correlation threshold
56 /// @details
57 /// The matcher scans lags in the range `[-maxLagSeconds, +maxLagSeconds]`
58 /// and finds the lag with the strongest absolute normalized cross-correlation.
59 ///
60 /// Polarity is considered preserved only if the signed correlation at the best lag
61 /// is greater than or equal to `minimumSignedCorrelation`.
62 ///
63 /// Lower values make polarity detection more tolerant to distortion, noise,
64 /// or other waveform changes, while higher values require a cleaner match.
65 ///
66 /// @param minimumSignedCorrelation Minimum required signed correlation between the input
67 /// and the output signal, in the [0, 1] range. If the observed correlation is in
68 /// `(-minimumSignedCorrelation, minimumSignedCorrelation)` range, the matcher will
69 /// fail due to the signals being weakly correlated. If correlation is in
70 /// [-1, -minimumSignedCorrelation] range, the polarity is considered inverted and the matcher fails.
71 /// If it falls into [minimumSignedCorrelation, 1] range, it is considered
72 /// preserved, and this is where the matcher will pass.
73 /// @param maxLagSeconds Maximum absolute lag to search in seconds; must be non-negative
74 /// @param silencePolicy Defines how channels with silence (zeros or almost zeros)
75 /// are handled. Available options are:
76 /// - `SilencePolicy::strict` - fails if any applicable channel is silent
77 /// - `SilencePolicy::relaxed` - ignores silent channels, as long as at least one
78 /// channel is not silent
79 /// @see hart::SilencePolicy
80 PolarityPreserved (double minimumSignedCorrelation = 0.5, double maxLagSeconds = 0.01, SilencePolicy silencePolicy = SilencePolicy::strict):
81 m_minimumSignedCorrelation (minimumSignedCorrelation),
82 m_maxLagSeconds (maxLagSeconds),
83 m_silencePolicy (silencePolicy)
84 {
 // Argument validation. NOTE(review): listing lines 86-87 and 93-94 (the opening of each
 // error-reporting call) are not shown in this listing.
85 if (m_minimumSignedCorrelation < 0 || m_minimumSignedCorrelation > 1.0)
88 "Signed correlation threshold should be in 0..1 range",
89 false
90 );
91
92 if (m_maxLagSeconds < 0)
95 "Max lag should be a non-negative number in seconds",
96 false
97 );
98 }
99
 /// @brief Stores the sample rate and converts the maximum lag from seconds to frames
 /// (rounded to the nearest integer). Input and output channel counts must match.
100 void prepare (
101 double sampleRateHz,
102 size_t numInputChannels,
103 size_t numOutputChannels,
104 size_t /*maxBlockSizeFrames*/
105 ) override
106 {
107 hassert (numInputChannels == numOutputChannels);
108 m_sampleRateHz = sampleRateHz;
109 m_maxLagFrames = static_cast<long long int> (std::round (m_maxLagSeconds * m_sampleRateHz));
110 }
111
 /// @brief Always returns `false`: match() analyses the complete input and output buffers at once.
112 bool canOperatePerBlock() const override
113 {
114 return false;
115 }
116
 /// @brief Clears the result of the previous match() call.
 /// The sample rate and lag range set in prepare() are left unchanged.
117 void reset() override
118 {
119 m_failureChannel = 0;
120 m_failureFrame = 0;
121 m_bestSignedCorrelation = 0.0;
122 m_bestLagFrames = 0;
123 m_hadValidData = false;
124 }
125
 /// @brief Only layouts with equal input and output channel counts are supported.
126 bool supportsChannelLayout (size_t numInputChannels, size_t numOutputChannels) const override
127 {
128 return numInputChannels == numOutputChannels;
129 }
130
 /// @brief Checks every applicable channel for preserved polarity
 /// @details For each applicable channel, searches lags in `[-m_maxLagFrames, +m_maxLagFrames]`
 /// for the largest absolute normalized cross-correlation and compares the signed correlation
 /// at that lag with `m_minimumSignedCorrelation`. On failure, the failing channel (and, when a
 /// correlation could be computed, that correlation and its lag) is stored for getFailureDetails().
 /// @return `true` if every applicable, non-silent channel passes; `false` if the buffers are
 /// empty, a channel is below the threshold, a silent channel is found under
 /// `SilencePolicy::strict`, or no applicable channel carries any signal.
131 bool match (AnalysisContext<SampleType> context) override
132 {
133 const AudioBuffer<SampleType>& inputAudio = context.inputAudio();
134 const AudioBuffer<SampleType>& observedOutputAudio = context.outputAudio();
135
136 hassert (inputAudio.getNumChannels() == observedOutputAudio.getNumChannels());
137 hassert (inputAudio.getNumFrames() == observedOutputAudio.getNumFrames());
138 hassert (inputAudio.getSampleRateHz() == observedOutputAudio.getSampleRateHz());
139
140 const size_t numFrames = inputAudio.getNumFrames();
141
142 if (numFrames == 0)
143 {
144 m_hadValidData = false;
145 return false;
146 }
147
148 const size_t numChannels = inputAudio.getNumChannels();
149 bool anyValidChannel = false;
150
151 for (size_t channel = 0; channel < numChannels; ++channel)
152 {
153 if (! this->appliesToChannel (channel))
154 continue;
155
156 const SampleType* x = inputAudio[channel];
157 const SampleType* y = observedOutputAudio[channel];
 // Prefix sums of squared samples: the energy of any overlap window is the difference
 // of two entries, so it does not need to be recomputed for every lag.
158 std::vector<double> prefixSumsSqX (numFrames + 1, 0.0);
159 std::vector<double> prefixSumsSqY (numFrames + 1, 0.0);
160 AccurateSum<double> runningSumSqX { 0.0 };
161 AccurateSum<double> runningSumSqY { 0.0 };
162
163 for (size_t frame = 0; frame < numFrames; ++frame)
164 {
165 const double xVal = static_cast<double> (x[frame]);
166 const double yVal = static_cast<double> (y[frame]);
167
168 runningSumSqX += xVal * xVal;
169 runningSumSqY += yVal * yVal;
170 prefixSumsSqX[frame + 1] = runningSumSqX;
171 prefixSumsSqY[frame + 1] = runningSumSqY;
172 }
173
174 double bestAbsCorrelation = -hart::inf;
175 double bestSignedCorrelation = 0.0;
176 long long int bestLag = 0;
177 bool channelValid = false;
178
 // Lags are scanned from -m_maxLagFrames upwards. On ties the earliest lag wins (strict `>`),
 // and the scan stops early once a (near-)perfect correlation is found.
 // NOTE(review): when m_maxLagFrames >= numFrames - 1, the first lag that is not skipped has
 // an overlap of a single frame, whose normalized correlation is always +/-1. For very short
 // buffers this degenerate lag wins and triggers the early break below, so the result
 // depends on one sample pair. Consider enforcing a minimum overlap length.
179 for (long long int lag = -m_maxLagFrames; lag <= m_maxLagFrames; ++lag)
180 {
181 AccurateSum<double> dotProduct { 0.0 };
182 const bool lagShiftsOutputToTheLeft = lag < 0;
183 const size_t lagAbsFrames = static_cast<size_t> (lagShiftsOutputToTheLeft ? -lag : lag);
184
185 if (lagAbsFrames >= numFrames)
186 continue;
187
188 // For a given lag, correlate only the valid overlap interval:
189 // x[inputOverlapBeginFrame + offset] with y[outputOverlapBeginFrame + offset].
 // A positive lag pairs x[n] with y[n + lag]; a negative lag pairs x[n + |lag|] with y[n].
190 const size_t inputOverlapBeginFrame = lagShiftsOutputToTheLeft ? lagAbsFrames : 0;
191 const size_t outputOverlapBeginFrame = lagShiftsOutputToTheLeft ? 0 : lagAbsFrames;
192 const size_t overlapSizeFrames = numFrames - lagAbsFrames;
193 const size_t inputOverlapEndFrame = inputOverlapBeginFrame + overlapSizeFrames;
194 const size_t outputOverlapEndFrame = outputOverlapBeginFrame + overlapSizeFrames;
195 const double sumSqX = prefixSumsSqX[inputOverlapEndFrame] - prefixSumsSqX[inputOverlapBeginFrame];
196 const double sumSqY = prefixSumsSqY[outputOverlapEndFrame] - prefixSumsSqY[outputOverlapBeginFrame];
197
198 for (size_t overlapFrame = 0; overlapFrame < overlapSizeFrames; ++overlapFrame)
199 {
200 const double inputValue = static_cast<double> (x[inputOverlapBeginFrame + overlapFrame]);
201 const double outputValue = static_cast<double> (y[outputOverlapBeginFrame + overlapFrame]);
202 dotProduct += inputValue * outputValue;
203 }
204
 // Correlation is undefined when either window has (near-)zero energy; skip such lags.
 // (The dot product above is still computed for them.)
205 if (floatsEqual (sumSqX, 0.0) || floatsEqual (sumSqY, 0.0))
206 continue;
207
208 channelValid = true;
209 const double corr = dotProduct / std::sqrt (sumSqX * sumSqY);
210 const double absCorr = std::abs (corr);
211
212 if (absCorr > bestAbsCorrelation)
213 {
214 bestAbsCorrelation = absCorr;
215 bestSignedCorrelation = corr;
216 bestLag = lag;
217 }
218
219 if (floatsEqual (absCorr, 1.0))
220 break;
221 }
222
 // No lag had energy in both signals: the channel is treated as silent.
223 if (! channelValid)
224 {
225 if (m_silencePolicy == SilencePolicy::strict)
226 {
227 m_hadValidData = false;
228 m_failureChannel = channel;
229 m_failureFrame = 0;
230 return false;
231 }
232
233 continue;
234 }
235
236 anyValidChannel = true;
237
 // Weak correlations and inverted polarity both fall below the threshold.
238 if (bestSignedCorrelation < m_minimumSignedCorrelation)
239 {
240 m_hadValidData = true;
241 m_failureChannel = channel;
242 m_failureFrame = 0;
243 m_bestSignedCorrelation = bestSignedCorrelation;
244 m_bestLagFrames = bestLag;
245 return false;
246 }
247 }
248
 // Every applicable channel was silent (relaxed policy) or no channel was applicable.
249 if (! anyValidChannel)
250 {
251 m_hadValidData = false;
252 m_failureChannel = 0;
253 m_failureFrame = 0;
254 return false;
255 }
256
257 return true;
258 }
259
 /// @brief Describes the most recent failure recorded by match()
 /// @details If no correlation could be computed (empty buffers or silence), a generic message
 /// is returned; otherwise the message reports the signed correlation and the lag (in frames
 /// and seconds) found for the failing channel. The reported frame is always 0.
 // NOTE(review): the signature line (listing line 260) is missing here; the member index lists
 // it as `MatcherFailureDetails getFailureDetails() const override`.
261 {
262 MatcherFailureDetails details;
263 details.channel = m_failureChannel;
264 details.frame = m_failureFrame;
265
266 if (!m_hadValidData)
267 {
268 details.description = "Polarity could not be determined with sufficient confidence";
269 return details;
270 }
271
272 const double lagSeconds = m_bestLagFrames / m_sampleRateHz;
273 std::stringstream stream;
274
275 stream
276 << "Detected signed correlation: "
277 << correlationPrecision << m_bestSignedCorrelation
278 << " at lag " << m_bestLagFrames << " frames ("
279 << secPrecision << lagSeconds << " seconds)";
280
281 details.description = stream.str();
282 return details;
283 }
284
 /// @brief Writes `PolarityPreserved (<threshold>, <maxLag>_s, SilencePolicy::<policy>)` to `stream`.
285 void represent (std::ostream& stream) const override
286 {
287 stream
288 << "PolarityPreserved ("
289 << correlationPrecision << m_minimumSignedCorrelation << ", "
290 << secPrecision << m_maxLagSeconds << "_s, "
291 << "SilencePolicy::"
292 << (m_silencePolicy == SilencePolicy::strict ? "strict" : "relaxed")
293 << ")";
294 }
295
296private:
 // Constructor arguments
297 const double m_minimumSignedCorrelation;
298 const double m_maxLagSeconds;
299 const SilencePolicy m_silencePolicy;
300
 // Set in prepare()
301 double m_sampleRateHz = 0.0;
302 long long int m_maxLagFrames = 0;
303
 // Result of the last failing match() call, reported by getFailureDetails()
304 double m_bestSignedCorrelation = 0.0;
305 long long int m_bestLagFrames = 0;
306
307 size_t m_failureChannel = 0;
308 size_t m_failureFrame = 0;
309 bool m_hadValidData = false;
310};
311
313
314} // namespace hart
Implements the Kahan summation algorithm for floating-point accumulation.
AccurateSum(SampleType initialSum=(SampleType) 0)
Inits AccurateSum with a specific value.
AccurateSum & operator+=(SampleType value)
Adds a value to a sum, tracking the potential floating point error.
Contains audio-related artefacts useful for analysis by matchers.
Container for audio data.
Base for audio matchers.
Checks whether the output signal preserves the polarity of the input signal.
MatcherFailureDetails getFailureDetails() const override
Returns a description of why the match has failed.
void represent(std::ostream &stream) const override
Makes a text representation of this Matcher for test failure outputs.
PolarityPreserved(double minimumSignedCorrelation=0.5, double maxLagSeconds=0.01, SilencePolicy silencePolicy=SilencePolicy::strict)
Creates a polarity matcher with a minimum signed correlation threshold.
bool match(AnalysisContext< SampleType > context) override
Tells the host if the piece of audio satisfies Matcher's condition or not.
void prepare(double sampleRateHz, size_t numInputChannels, size_t numOutputChannels, size_t) override
Prepares for processing. It is guaranteed that all subsequent process() calls will be in line with the ...
bool canOperatePerBlock() const override
Tells the host if it can operate on a block-by-block basis.
void reset() override
Resets the matcher to its initial state.
bool supportsChannelLayout(size_t numInputChannels, size_t numOutputChannels) const override
Tells the host whether this Matcher is capable of operating on audio with a specific number of channe...
Thrown when an inappropriate value is encountered.
#define hassert(condition)
Triggers a HartAssertException if the condition is false
#define HART_THROW_OR_RETURN(ExceptionType, message, returnValue)
Throws an exception unless HART_DO_NOT_THROW_EXCEPTIONS is set; otherwise prints a message and returns a specified ...
std::ostream & secPrecision(std::ostream &stream)
Sets number of decimal places for values in seconds.
static std::ostream & correlationPrecision(std::ostream &stream)
Sets number of decimal places for correlation values.
SilencePolicy
Defines how silence is handled in various algorithms.
constexpr double inf
Infinity.
static SampleType floatsEqual(SampleType a, SampleType b, SampleType epsilon=(SampleType) 1e-8)
Compares two floating point numbers within a given tolerance.
#define HART_MATCHER_DECLARE_ALIASES_FOR(ClassName)
size_t channel
Index of channel at which the failure was detected.
std::string description
Readable description of why the match has failed.
size_t frame
Index of frame at which the match has failed.