HART  0.2.0
High level Audio Regression and Testing
Loading...
Searching...
No Matches
hart_correlation_latency_detector.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <sstream>
4#include <vector>
5
11#include "hart_utils.hpp" // make_unique(), roundToSizeT(), inf, floatsEqual()
12
13namespace hart
14{
15
16/// @brief Correlation-based latency detector implementation for the hart::LatencyBelow class. For internal use.
17/// @private
18template <typename SampleType>
19class CorrelationLatencyDetector :
20 public LatencyDetector<SampleType>
21{
22public:
23 CorrelationLatencyDetector (double maxLatencySeconds, SilencePolicy silencePolicy, double absCorrelationThreshold) :
24 m_maxLatencySeconds (maxLatencySeconds),
25 m_silencePolicy (silencePolicy),
26 m_absCorrelationThreshold (absCorrelationThreshold)
27 {
28 if (absCorrelationThreshold > 1.0)
29 HART_THROW_OR_RETURN_VOID (hart::ValueError, "Normalized correlation threshold can not be higher than 1");
30
31 if (absCorrelationThreshold < 0.0)
32 HART_THROW_OR_RETURN_VOID (hart::ValueError, "Correlation threshold is an absolute value, so should not be negative");
33
34 // Technically, zero correlation is okay, but a bit too weird...
35 if (floatsEqual (absCorrelationThreshold, 0.0))
36 HART_THROW_OR_RETURN_VOID (hart::ValueError, "Zero correlation threshold is not a meaningful value in latency detector context");
37 }
38
39 std::unique_ptr<LatencyDetector<SampleType>> copy() const override
40 {
41 return hart::make_unique<CorrelationLatencyDetector<SampleType>> (*this);
42 }
43
44 void prepare (double sampleRateHz, size_t /* numChannels */, size_t /* maxBlockSizeFrames */) override
45 {
46 m_sampleRateHz = sampleRateHz;
47 }
48
49 void reset() override
50 {
51 m_hadValidData = false;
52 m_failureChannel = 0;
53 m_failureFrame = 0;
54 m_detectedLatencyFrames = 0;
55 m_bestCorrelation = 0.0;
56 }
57
58 bool match (
59 const AudioBuffer<SampleType>& inputAudio,
60 const AudioBuffer<SampleType>& observedOutputAudio,
61 const std::function<bool (size_t)>& appliesToChannel
62 ) override
63 {
64 const size_t numFrames = inputAudio.getNumFrames();
65
66 if (numFrames == 0)
67 {
68 m_hadValidData = false;
69 return false;
70 }
71
72 const size_t maxLagFrames = numFrames - 1;
73 bool anyValidChannel = false;
74 size_t worstLatencyFrames = 0;
75 size_t worstChannel = 0;
76
77 for (size_t channel = 0; channel < inputAudio.getNumChannels(); ++channel)
78 {
79 if (! appliesToChannel (channel))
80 continue;
81
82 const SampleType* x = inputAudio[channel];
83 const SampleType* y = observedOutputAudio[channel];
84 std::vector<double> prefixSumsSqX (numFrames + 1, 0.0);
85 std::vector<double> prefixSumsSqY (numFrames + 1, 0.0);
86 AccurateSum<double> runningSumSqX { 0.0 };
87 AccurateSum<double> runningSumSqY { 0.0 };
88
89 for (size_t frame = 0; frame < numFrames; ++frame)
90 {
91 const double xVal = static_cast<double> (x[frame]);
92 const double yVal = static_cast<double> (y[frame]);
93
94 runningSumSqX += xVal * xVal;
95 runningSumSqY += yVal * yVal;
96 prefixSumsSqX[frame + 1] = runningSumSqX;
97 prefixSumsSqY[frame + 1] = runningSumSqY;
98 }
99
100 double bestAbsCorrelation = -hart::inf;
101 size_t bestLag = 0;
102 bool channelValid = false;
103
104 for (size_t lag = 0; lag <= maxLagFrames; ++lag)
105 {
106 AccurateSum<double> dotProduct = { 0.0 };
107 const size_t inputOverlapBeginFrame = 0;
108 const size_t outputOverlapBeginFrame = lag;
109 const size_t overlapSizeFrames = numFrames - lag;
110 const size_t inputOverlapEndFrame = inputOverlapBeginFrame + overlapSizeFrames;
111 const size_t outputOverlapEndFrame = outputOverlapBeginFrame + overlapSizeFrames;
112 const double sumSqX = prefixSumsSqX[inputOverlapEndFrame] - prefixSumsSqX[inputOverlapBeginFrame];
113 const double sumSqY = prefixSumsSqY[outputOverlapEndFrame] - prefixSumsSqY[outputOverlapBeginFrame];
114
115 for (size_t overlapFrame = 0; overlapFrame < overlapSizeFrames; ++overlapFrame)
116 {
117 const double inputValue = static_cast<double> (x[inputOverlapBeginFrame + overlapFrame]);
118 const double outputValue = static_cast<double> (y[outputOverlapBeginFrame + overlapFrame]);
119 dotProduct += inputValue * outputValue;
120 }
121
122 if (floatsEqual (sumSqX, 0.0) || floatsEqual (sumSqY, 0.0))
123 continue;
124
125 channelValid = true;
126 const double correlation = dotProduct / std::sqrt (sumSqX * sumSqY);
127 const double absCorrelation = std::abs (correlation);
128
129 if (absCorrelation > bestAbsCorrelation)
130 {
131 bestAbsCorrelation = absCorrelation;
132 bestLag = lag;
133 }
134
135 if (floatsEqual (bestAbsCorrelation, 1.0))
136 break;
137 }
138
139 if (! channelValid || bestAbsCorrelation < m_absCorrelationThreshold)
140 {
141 if (m_silencePolicy == SilencePolicy::strict)
142 {
143 m_hadValidData = false;
144 m_failureChannel = channel;
145 return false;
146 }
147
148 continue;
149 }
150
151 anyValidChannel = true;
152
153 if (bestLag > worstLatencyFrames)
154 {
155 worstLatencyFrames = bestLag;
156 worstChannel = channel;
157 m_bestCorrelation = bestAbsCorrelation;
158 }
159 }
160
161 if (!anyValidChannel)
162 {
163 m_hadValidData = false;
164 return false;
165 }
166
167 m_hadValidData = true;
168 m_detectedLatencyFrames = worstLatencyFrames;
169 const double latencySeconds = m_detectedLatencyFrames / m_sampleRateHz;
170
171 if (latencySeconds <= m_maxLatencySeconds)
172 return true;
173
174 m_failureChannel = worstChannel;
175 return false;
176 }
177
178 MatcherFailureDetails getFailureDetails() const override
179 {
180 MatcherFailureDetails details;
181 details.channel = m_failureChannel;
182 details.frame = m_failureFrame;
183
184 if (! m_hadValidData)
185 {
186 details.description = "Latency could not be determined with sufficient correlation";
187 return details;
188 }
189
190 const double latencySeconds = m_detectedLatencyFrames / m_sampleRateHz;
191
192 std::stringstream descriptionStream;
193 descriptionStream
194 << "Detected latency: "
195 << secPrecision << latencySeconds << " seconds ("
196 << m_detectedLatencyFrames << " frames), "
197 << "best correlation: " << correlationPrecision << m_bestCorrelation;
198
199 details.description = descriptionStream.str();
200 return details;
201 }
202
203private:
204 const double m_maxLatencySeconds;
205 const SilencePolicy m_silencePolicy;
206 const double m_absCorrelationThreshold;
207 double m_sampleRateHz = 0.0;
208
209 bool m_hadValidData = false;
210 size_t m_detectedLatencyFrames = 0;
211 double m_bestCorrelation = 0.0;
212 size_t m_failureChannel = 0;
213 size_t m_failureFrame = 0;
214};
215
216} // namespace hart
Implements Kahan algorithm for floating point accumulations.
AccurateSum(SampleType initialSum=(SampleType) 0)
Inits AccurateSum with a specific value.
AccurateSum & operator+=(SampleType value)
Adds a value to a sum, tracking the potential floating point error.
Container for audio data.
Thrown when an inappropriate value is encountered.
#define HART_THROW_OR_RETURN_VOID(ExceptionType, message)
Throws an exception unless HART_DO_NOT_THROW_EXCEPTIONS is set, in which case it prints a message and returns instead.
std::ostream & secPrecision(std::ostream &stream)
Sets number of decimal places for values in seconds.
static std::ostream & correlationPrecision(std::ostream &stream)
Sets number of decimal places for correlation values.
SilencePolicy
Defines how silence is handled in various algorithms.
constexpr double inf
Infinity.
static SampleType floatsEqual(SampleType a, SampleType b, SampleType epsilon=(SampleType) 1e-8)
Compares two floating point numbers within a given tolerance.
size_t channel
Index of channel at which the failure was detected.
std::string description
Readable description of why the match has failed.
size_t frame
Index of frame at which the match has failed.