HART  0.2.0
High level Audio Regression and Testing
Loading...
Searching...
No Matches
hart_lag_at_max_cross_correlation.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm> // min()
4#include <cmath> // abs(), sqrt()
5
9#include "metrics/hart_metric_query.hpp"
10#include "metrics/hart_metrics_common.hpp" // CorrelationSearchMode, ChannelSubsets
11#include "hart_slice.hpp"
12#include "hart_units.hpp" // Unit
13#include "hart_utils.hpp" // floatsEqual()
14
15namespace hart
16{
17
18/// @brief Calculates lag corresponding to maximum normalized cross-correlation between two audio buffers
19/// @details
20/// Searches for the lag producing the strongest normalized cross-correlation
21/// independently for each selected pair of channels.
22///
23/// Cross-correlation is calculated using the following formula:
24/// @f[
25/// \frac{\sum_n x[n]\,y[n+k]}
26/// {\sqrt{
27/// \left(\sum_n x[n]^2\right)
28/// \left(\sum_n y[n+k]^2\right)
29/// }}
30/// @f]
31///
32/// (`sum (x[n] * y[n + k]) / sqrt (sum (x[n]^2) * sum (y[n + k]^2))`)
33///
34/// where:
35/// - `x[n]` is the left-hand-side signal
36/// - `y[n + k]` is the right-hand-side signal shifted by lag `k`
37/// - `k` is searched in the range `[-maxLag, +maxLag]`
38///
39/// Positive lag means that `bufferB` is delayed relative to `bufferA`.
40///
41/// Depending on `searchMode`, the metric either:
42/// - searches for the largest signed correlation value
43/// - or searches for the largest absolute correlation value
44///
45/// Correlation is calculated independently for each selected pair of channels.
46/// Use a reducer to combine multiple lag values into a scalar.
47///
48/// Supports `Unit::frames` (default/native) and `Unit::seconds`.
49/// For conversion to seconds, it uses sample rate metadata contained
50/// in the provided buffers.
51///
52/// Usage examples:
53/// @code
54/// // Detect latency in frames
55/// const double lagFrames = lagAtMaxCrossCorrelation (input, output, 100_ms).get();
56///
57/// // Same, but returned in seconds
58/// const double lagSeconds = lagAtMaxCrossCorrelation (input, output, 100_ms)
59/// .as (seconds)
60/// .get();
61///
62/// // Strongest lag across matched stereo channels
63/// const double maxLagFrames = lagAtMaxCrossCorrelation (input, output, 100_ms)
64/// .get (max());
65///
66/// // Custom channel mapping
67/// // Lags between:
68/// // - input, channel 0 vs output, channel 1
69/// // - input, channel 3 vs output, channel 3
70/// // - input, channel 1 vs output, channel 2
71/// const double swappedLagFrames = lagAtMaxCrossCorrelation (multiChanneInput, multiChanneOutput, 100_ms)
72/// .ch ({{0, 1}, {3, 3}, {1, 2}})
73/// .get (mean());
74/// @endcode
75///
76/// Notes:
77/// - Gain differences do not affect the result due to normalization.
78/// - Returned lag may be negative.
79/// - If no valid overlap exists, returns `NaN`.
80///
81/// @param bufferA Left-hand-side audio buffer
82/// @param bufferB Right-hand-side audio buffer
83/// @param maxLagSeconds Maximum lag to search in seconds
84/// @param minAbsBestCorrelation If best correlation (rectified) is under this value,
85/// then signals will be considered to not have valid overlap, and result will be NaN
86/// @param searchMode Controls how the best lag is selected
87///
88/// @return MetricQuery containing lag values corresponding to best cross-correlation
89///
90/// @tparam SampleType Floating point sample type, typically `float` or `double`
91///
92/// @throws hart::ValueError If `maxLagSeconds` is negative
93/// @throws hart::SampleRateError If sample rates differ
94/// @throws hart::IndexError If requested channel indices are out of range
95/// @throws hart::UnitError If unsupported unit is requested
96///
97/// @ingroup Metrics
98template <typename SampleType>
100 const AudioBuffer<SampleType>& bufferA,
101 const AudioBuffer<SampleType>& bufferB,
102 double maxLagSeconds,
103 double minAbsBestCorrelation = 0.5,
105)
106{
107 if (maxLagSeconds < 0.0)
108 HART_THROW_OR_RETURN (hart::ValueError, "Maximum lag must be non-negative", {});
109
110 if ((bufferA.hasSampleRate() || bufferB.hasSampleRate()) && bufferA.getSampleRateHz() != bufferB.getSampleRateHz())
111 HART_THROW_OR_RETURN (hart::SampleRateError, "Audio buffers must have equal sample rates", {});
112
113 if (minAbsBestCorrelation < 0 || minAbsBestCorrelation > 1.0)
114 HART_THROW_OR_RETURN (hart::ValueError, "minAbsBestCorrelation should be in 0..1 range", {});
115
116 typename MetricQuery<double>::ChannelPairMetricEvaluator evaluator =
117 [&bufferA, &bufferB, maxLagSeconds, minAbsBestCorrelation, searchMode]
118 (size_t channelA, size_t channelB, Slice slice, Unit requestedUnit)
119 -> double
120 {
121 // Should be checked by MetricQuery
122 hassert (channelA < bufferA.getNumChannels());
123 hassert (channelB < bufferB.getNumChannels());
124
125 if (requestedUnit != Unit::native &&
126 requestedUnit != Unit::frames &&
127 requestedUnit != Unit::seconds
128 )
129 {
130 HART_THROW_OR_RETURN (hart::UnitError, "Unsupported unit", hart::nan<double>());
131 }
132
133 if (requestedUnit == Unit::seconds)
134 {
135 if (! bufferA.hasSampleRate()
136 || ! bufferB.hasSampleRate()
137 )
138 {
139 HART_THROW_OR_RETURN (hart::SampleRateError, "Audio buffers must have sample rate metadata to convert lag to seconds", hart::nan<double>());
140 }
141
142 if (hart::floatsEqual (bufferA.getSampleRateHz(), 0.0)
143 || hart::floatsEqual (bufferB.getSampleRateHz(), 0.0)
144 )
145 {
146 HART_THROW_OR_RETURN (hart::SampleRateError, "Audio buffers must have non-zero sample rates to convert lag to seconds", hart::nan<double>());
147 }
148 }
149
150 // This might be a bit too strict. So, if a legit case with two buffers with
151 // mismatched lengths presents itself, remove this check and handle it properly.
152 if (bufferA.getNumFrames() != bufferB.getNumFrames())
153 HART_THROW_OR_RETURN (hart::SizeError, "Audio buffers must have matching n umber of frames", hart::nan<double>());
154
155 if (slice.isEmpty())
156 return hart::nan<double>();
157
158 const auto sliceFrameIndices = bufferA.getFrameIndices (slice);
159 const size_t sliceStart = sliceFrameIndices.first;
160 const size_t sliceStop = sliceFrameIndices.second;
161 hassert (sliceStop > sliceStart);
162 hassert (sliceStop <= bufferA.getNumFrames());
163 hassert (sliceStop <= bufferB.getNumFrames());
164
165 const size_t numFrames = sliceStop - sliceStart;
166 hassert (numFrames != 0);
167
168 const double sampleRateHz = bufferA.getSampleRateHz();
169 const size_t maxLagFrames = static_cast<size_t> (std::round (maxLagSeconds * sampleRateHz));
170
171 const SampleType* x = bufferA[channelA] + sliceStart;
172 const SampleType* y = bufferB[channelB] + sliceStart;
173
174 double bestCorrelation =
175 (searchMode == bestSignedCorrelation)
176 ? -hart::inf
177 : 0.0;
178
179 int bestLagFrames = 0;
180 bool hadValidOverlap = false;
181
182 for (int lag = -static_cast<int> (maxLagFrames); lag <= static_cast<int> (maxLagFrames); ++lag)
183 {
184 const bool lagIsNegative = lag < 0;
185 const size_t lagAbsFrames = static_cast<size_t> (lagIsNegative ? -lag : lag);
186
187 if (lagAbsFrames >= numFrames)
188 continue;
189
190 const size_t xBegin = lagIsNegative ? lagAbsFrames : 0;
191 const size_t yBegin = lagIsNegative ? 0 : lagAbsFrames;
192 const size_t overlapFrames = numFrames - lagAbsFrames;
193 AccurateSum<double> dotProduct;
194 AccurateSum<double> sumSquaresX;
195 AccurateSum<double> sumSquaresY;
196
197 for (size_t frame = 0; frame < overlapFrames; ++frame)
198 {
199 const double xn = static_cast<double> (x[xBegin + frame]);
200 const double yn = static_cast<double> (y[yBegin + frame]);
201
202 dotProduct += xn * yn;
203 sumSquaresX += xn * xn;
204 sumSquaresY += yn * yn;
205 }
206
207 const double energyX = sumSquaresX.getValue();
208 const double energyY = sumSquaresY.getValue();
209
210 if (floatsEqual (energyX, 0.0) || floatsEqual (energyY, 0.0))
211 continue;
212
213 hadValidOverlap = true;
214
215 const double correlation = dotProduct.getValue() / std::sqrt (energyX * energyY);
216
217 bool isBetter =
218 (searchMode == bestSignedCorrelation)
219 ? correlation > bestCorrelation
220 : std::abs (correlation) > std::abs (bestCorrelation);
221
222 if (isBetter)
223 {
224 bestCorrelation = correlation;
225 bestLagFrames = lag;
226 }
227
228 if (floatsEqual (std::abs (bestCorrelation), 1.0))
229 break;
230 }
231
232 if (std::abs (bestCorrelation) < minAbsBestCorrelation)
233 hadValidOverlap = false;
234
235 if (! hadValidOverlap)
236 return hart::nan<double>();
237
238 switch (requestedUnit)
239 {
240 case Unit::native:
241 case Unit::frames:
242 return static_cast<double> (bestLagFrames);
243
244 case Unit::seconds:
245 return bestLagFrames / sampleRateHz;
246
247 default: // Should be unreachable
248 HART_THROW_OR_RETURN (hart::UnitError, "Unsupported unit", hart::nan<double>());
249 }
250 };
251
252 const size_t numPairs = std::min (bufferA.getNumChannels(), bufferB.getNumChannels());
253 return MetricQuery<double> (
254 std::move (evaluator),
255 bufferA.getNumChannels(),
256 bufferB.getNumChannels(),
258 );
259}
260
261} // namespace hart
Implements Kahan algorithm for floating point accumulations.
SampleType getValue() const
AccurateSum & operator+=(SampleType value)
Adds a value to a sum, tracking the potential floating point error.
Container for audio data.
Manages the metrics calculations.
Thrown when sample rate is mismatched.
Thrown when an unexpected container size is encountered.
Thrown when some metric is requested to return a value in an unsupported unit.
Thrown when an inappropriate value is encountered.
#define hassert(condition)
Triggers a HartAssertException if the condition is false
#define HART_THROW_OR_RETURN(ExceptionType, message, returnValue)
Throws an exception if HART_DO_NOT_THROW_EXCEPTIONS is set, prints a message and returns a specified ...
MetricQuery< double > lagAtMaxCrossCorrelation(const AudioBuffer< SampleType > &bufferA, const AudioBuffer< SampleType > &bufferB, double maxLagSeconds, double minAbsBestCorrelation=0.5, CorrelationSearchMode searchMode=bestAbsoluteCorrelation)
Calculates lag corresponding to maximum normalized cross-correlation between two audio buffers.
FloatType nan()
Returns a quiet NaN value for the given floating-point type.
constexpr double inf
Infinity.
static SampleType floatsEqual(SampleType a, SampleType b, SampleType epsilon=(SampleType) 1e-8)
Compares two floating point numbers within a given tolerance.
CorrelationSearchMode
Describes how to look for best cross-correlation.
Unit
Represents a physical unit.
@ seconds
Time stamps, intervals, durations.
@ native
Default (native) unit of whatever returns some value.
@ frames
Value of something in frames (samples)
Helpers to generate common default channel subsets.
static std::vector< std::pair< size_t, size_t > > diagonalChannelPairs(size_t numChannels)
Represents a slice of analysis data.
bool isEmpty() const