HART  0.2.0
High level Audio Regression and Testing
Loading...
Searching...
No Matches
hart_max_cross_correlation.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm> // min()
4#include <cmath> // abs()
5
9#include "metrics/hart_metric_query.hpp"
10#include "metrics/hart_metrics_common.hpp" // CorrelationSearchMode, ChannelSubsets
11#include "hart_slice.hpp"
12#include "hart_units.hpp" // Unit
13#include "hart_utils.hpp" // roundToSizeT()
14
15namespace hart
16{
17
18/// @brief Calculates maximum normalized cross-correlation between two audio buffers
19/// @details
20/// Searches for the best normalized cross-correlation value within a specified
21/// lag range independently for each selected pair of channels.
22///
23/// Cross-correlation is calculated using the following formula:
24/// @f[
25/// \frac{\sum_n x[n]\,y[n+k]}
26/// {\sqrt{
27/// \left(\sum_n x[n]^2\right)
28/// \left(\sum_n y[n+k]^2\right)
29/// }}
30/// @f]
31///
32/// (`sum (x[n] * y[n + k]) / sqrt (sum (x[n]^2) * sum (y[n + k]^2))`)
33///
34/// where:
35/// - `x[n]` is the left-hand-side signal
36/// - `y[n + k]` is the right-hand-side signal shifted by lag `k`
37/// - `k` is searched in the range `[-maxLag, +maxLag]`
38///
39/// The result is normalized to the range `[-1, 1]`, where:
40/// - `+1` means perfect positive correlation
41/// - `-1` means perfect negative correlation (polarity inversion)
42/// - `0` means no linear correlation
43///
44/// Depending on @p searchMode, the metric either:
45/// - searches for the largest signed correlation value
46/// - or searches for the largest absolute correlation value while still returning
47/// the original signed correlation.
48///
49/// Correlation is calculated independently for each selected pair of channels.
50/// Use a reducer to combine multiple channel-pair results into a scalar.
51///
52/// Usage examples:
53/// @code
54/// // Mono signals, default channel mapping {0,0}
55/// const double corr = maxCrossCorrelation (monoInput, monoOutput, 100_ms).get();
56///
57/// // Same, but polarity-invariant lag search
58/// const double corrAbs = maxCrossCorrelation (
59/// monoInput,
60/// monoOutput,
61/// 100_ms,
62/// bestAbsoluteCorrelation
63/// ).get();
64///
65/// // Stereo signals, strongest matched pair correlation
66/// const double maxCorr = maxCrossCorrelation (stereoInput, stereoOutput, 100_ms).get (max());
67///
68/// // Cross-map channels explicitly
69/// const double swappedCorr = maxCrossCorrelation (input, output, 100_ms)
70/// .ch ({ {0, 1}, {1, 0} })
71/// .get (min());
72///
73/// // Detect polarity inversion
74/// const double corrSigned = maxCrossCorrelation (
75/// input,
76/// invertedOutput,
77/// 100_ms,
78/// bestAbsoluteCorrelation
79/// ).get();
80///
81/// HART_EXPECT_LT (corrSigned, 0.0);
82/// @endcode
83///
84/// Notes:
85/// - Gain differences do not affect the result due to normalization.
86/// - DC offset may reduce correlation.
87/// - Heavy non-linear processing may significantly reduce correlation.
88/// - Returned value remains signed even in `bestAbsoluteCorrelation` mode.
89/// - If no valid overlap exists, returns `NaN`.
90///
91/// Supports only `Unit::native` and `Unit::none` units.
92///
93/// @param bufferA Left-hand-side audio buffer
94/// @param bufferB Right-hand-side audio buffer
95/// @param maxLagSeconds Maximum lag to search in seconds
96/// @param searchMode Controls how the best lag is selected, see @ref `CorrelationSearchMode`
97///
98/// @return MetricQuery containing signed normalized cross-correlation values
99///
100/// @tparam SampleType Floating point sample type, typically `float` or `double`
101///
102/// @throws hart::ValueError If `maxLagSeconds` is negative
103/// @throws hart::SampleRateError If sample rates differ
104/// @throws hart::IndexError If requested channel indices are out of range
105/// @throws hart::UnitError If unsupported unit is requested
106///
107/// @ingroup Metrics
108template <typename SampleType>
110 const AudioBuffer<SampleType>& bufferA,
111 const AudioBuffer<SampleType>& bufferB,
112 double maxLagSeconds,
114)
115{
116 if (maxLagSeconds < 0.0)
117 HART_THROW_OR_RETURN (hart::ValueError,"Maximum lag must be non-negative", {});
118
119 if ((bufferA.hasSampleRate() || bufferB.hasSampleRate()) && bufferA.getSampleRateHz() != bufferB.getSampleRateHz())
120 HART_THROW_OR_RETURN (hart::SampleRateError, "Audio buffers must have equal sample rates", {});
121
122 typename MetricQuery<double>::ChannelPairMetricEvaluator evaluator =
123 [&bufferA, &bufferB, maxLagSeconds, searchMode]
124 (size_t channelA,size_t channelB, Slice slice, Unit requestedUnit)
125 -> double
126 {
127 // Should be checked by MetricQuery
128 hassert (channelA < bufferA.getNumChannels());
129 hassert (channelB < bufferB.getNumChannels());
130
131 if (requestedUnit != Unit::native && requestedUnit != Unit::none)
132 HART_THROW_OR_RETURN (hart::UnitError, "Cross-correlation does not support requested unit", hart::nan<double>());
133
134 if (slice.isEmpty())
135 return hart::nan<double>();
136
137 const auto sliceFrameIndices = bufferA.getFrameIndices (slice);
138 const size_t sliceStart = sliceFrameIndices.first;
139 const size_t sliceStop = sliceFrameIndices.second;
140 hassert (sliceStop > sliceStart);
141 hassert (sliceStop <= bufferA.getNumFrames());
142 hassert (sliceStop <= bufferB.getNumFrames());
143
144 const size_t numFrames = sliceStop - sliceStart;
145 hassert (numFrames != 0);
146
147 const double sampleRateHz = bufferA.getSampleRateHz();
148 const size_t maxLagFrames = roundToSizeT (maxLagSeconds * sampleRateHz);
149
150 const SampleType* x = bufferA[channelA] + sliceStart;
151 const SampleType* y = bufferB[channelB] + sliceStart;
152
153 double bestCorrelation = (searchMode == bestSignedCorrelation) ? -hart::inf : 0.0;
154 bool hadValidOverlap = false;
155
156 for (int lag = -static_cast<int> (maxLagFrames); lag <= static_cast<int> (maxLagFrames); ++lag)
157 {
158 const bool lagIsNegative = lag < 0;
159 const size_t lagAbsFrames = static_cast<size_t> (lagIsNegative ? -lag : lag);
160
161 if (lagAbsFrames >= numFrames)
162 continue;
163
164 const size_t xBegin = lagIsNegative ? lagAbsFrames : 0;
165 const size_t yBegin = lagIsNegative ? 0 : lagAbsFrames;
166 const size_t overlapFrames = numFrames - lagAbsFrames;
167
168 AccurateSum<double> dotProduct;
169 AccurateSum<double> sumSquaresX;
170 AccurateSum<double> sumSquaresY;
171
172 for (size_t frame = 0; frame < overlapFrames; ++frame)
173 {
174 const double xn = static_cast<double> (x[xBegin + frame]);
175 const double yn = static_cast<double> (y[yBegin + frame]);
176
177 dotProduct += xn * yn;
178 sumSquaresX += xn * xn;
179 sumSquaresY += yn * yn;
180 }
181
182 const double energyX = sumSquaresX.getValue();
183 const double energyY = sumSquaresY.getValue();
184
185 if (floatsEqual (energyX, 0.0) || floatsEqual (energyY, 0.0))
186 continue;
187
188 hadValidOverlap = true;
189 const double correlation = dotProduct.getValue() / std::sqrt (energyX * energyY);
190
191 if (searchMode == bestSignedCorrelation)
192 {
193 if (correlation > bestCorrelation)
194 bestCorrelation = correlation;
195 }
196 else // bestAbsoluteCorrelation
197 {
198 if (std::abs (correlation) > std::abs (bestCorrelation))
199 bestCorrelation = correlation;
200 }
201
202 if (floatsEqual (std::abs (bestCorrelation), 1.0))
203 break;
204 }
205
206 if (! hadValidOverlap)
207 return hart::nan<double>();
208
209 return bestCorrelation;
210 };
211
212 const size_t numPairs = std::min (bufferA.getNumChannels(), bufferB.getNumChannels());
213 return MetricQuery<double> (
214 std::move (evaluator),
215 bufferA.getNumChannels(),
216 bufferB.getNumChannels(),
218 );
219}
220
221} // namespace hart
Implements Kahan algorithm for floating point accumulations.
SampleType getValue() const
AccurateSum & operator+=(SampleType value)
Adds a value to a sum, tracking the potential floating point error.
Container for audio data.
Manages the metrics calculations.
Thrown when sample rate is mismatched.
Thrown when some metric is requested to return a value in an unsupported unit.
Thrown when an inappropriate value is encountered.
#define hassert(condition)
Triggers a HartAssertException if the condition is false
#define HART_THROW_OR_RETURN(ExceptionType, message, returnValue)
Throws an exception unless HART_DO_NOT_THROW_EXCEPTIONS is set; otherwise prints a message and returns the specified return value.
MetricQuery< double > maxCrossCorrelation(const AudioBuffer< SampleType > &bufferA, const AudioBuffer< SampleType > &bufferB, double maxLagSeconds, CorrelationSearchMode searchMode=bestAbsoluteCorrelation)
Calculates maximum normalized cross-correlation between two audio buffers.
FloatType nan()
Returns a quiet NaN value for the given floating-point type.
static size_t roundToSizeT(SampleType x)
Rounds a floating point value to a size_t value.
constexpr double inf
Infinity.
static SampleType floatsEqual(SampleType a, SampleType b, SampleType epsilon=(SampleType) 1e-8)
Compares two floating point numbers within a given tolerance.
CorrelationSearchMode
Describes how to look for best cross-correlation.
Unit
Represents a physical unit.
@ none
Unitless value.
@ native
Default (native) unit of whatever quantity is being returned.
Helpers to generate common default channel subsets.
static std::vector< std::pair< size_t, size_t > > diagonalChannelPairs(size_t numChannels)
Represents a slice of analysis data.
bool isEmpty() const