blob: 53211e3820eda4184a66b0605b008881341910d8 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "PosePredictorVerifier.h"
#include <memory>
#include <audio_utils/Statistics.h>
#include <media/PosePredictorType.h>
#include <media/Twist.h>
#include <media/VectorRecorder.h>
namespace android::media {
/**
 * Abstract interface shared by every pose predictor in this header.
 *
 * Samples are fed in through add() and queries are answered by predict().
 * Thread-safety is not part of this contract; each implementation documents its own.
 */
class PredictorBase {
public:
    virtual ~PredictorBase() = default;

    // Records one sample: the pose and twist observed at time atNs.
    virtual auto add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) -> void = 0;

    // Returns the predicted pose for time atNs.
    virtual auto predict(int64_t atNs) const -> Pose3f = 0;

    // Returns the predictor to its initial state (the implementations in this
    // header reset every stored member to its default value).
    virtual auto reset() -> void = 0;

    // Short identifier for the predictor, e.g. "LAST" or "TWIST".
    virtual auto name() const -> std::string = 0;

    // Human-readable dump of internal state. The implementations in this header
    // use `index` as the number of leading spaces (indentation) of the output.
    virtual auto toString(size_t index) const -> std::string = 0;
};
/**
* LastPredictor uses the last sample Pose for prediction
*
* This class is not thread-safe.
*/
class LastPredictor : public PredictorBase {
public:
void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override {
(void)atNs;
(void)twist;
mLastPose = pose;
}
Pose3f predict(int64_t atNs) const override {
(void)atNs;
return mLastPose;
}
void reset() override {
mLastPose = {};
}
std::string name() const override {
return "LAST";
}
std::string toString(size_t index) const override {
std::string s(index, ' ');
s.append("LastPredictor using last pose: ")
.append(mLastPose.toString())
.append("\n");
return s;
}
private:
Pose3f mLastPose;
};
/**
* TwistPredictor uses the last sample Twist and Pose for prediction
*
* This class is not thread-safe.
*/
class TwistPredictor : public PredictorBase {
public:
void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override {
mLastAtNs = atNs;
mLastPose = pose;
mLastTwist = twist;
}
Pose3f predict(int64_t atNs) const override {
return mLastPose * integrate(mLastTwist, atNs - mLastAtNs);
}
void reset() override {
mLastAtNs = {};
mLastPose = {};
mLastTwist = {};
}
std::string name() const override {
return "TWIST";
}
std::string toString(size_t index) const override {
std::string s(index, ' ');
s.append("TwistPredictor using last pose: ")
.append(mLastPose.toString())
.append(" last twist: ")
.append(mLastTwist.toString())
.append("\n");
return s;
}
private:
int64_t mLastAtNs{};
Pose3f mLastPose;
Twist3f mLastTwist;
};
/**
 * LeastSquaresPredictor uses the Pose history for prediction.
 *
 * An exponentially weighted least-squares fit is used. Four
 * audio_utils::LinearLeastSquaresFit estimators (mRw, mRx, mRy, mRz) are all
 * constructed with the same decay factor alpha. Presumably each fits one component
 * (w/x/y/z) of the rotation quaternion; that mapping lives in the out-of-line
 * definitions, so confirm it there.
 *
 * This class is not thread-safe.
 */
class LeastSquaresPredictor : public PredictorBase {
public:
    // Default-constructs with kDefaultAlphaEstimator. Kept non-explicit so that
    // copy-list-initialization (e.g. `LeastSquaresPredictor p = {};` or an aggregate
    // member left defaulted) keeps working as it did with the old defaulted argument.
    LeastSquaresPredictor() : LeastSquaresPredictor(kDefaultAlphaEstimator) {}

    /**
     * @param alpha exponential decay factor, forwarded unchanged to every
     *        LinearLeastSquaresFit. NOTE(review): the range is not validated here;
     *        presumably (0, 1] with 1 meaning "no decay" — confirm against
     *        audio_utils::LinearLeastSquaresFit.
     *
     * explicit: a bare double must not silently convert into a predictor
     * (previously `LeastSquaresPredictor p = 0.3;` compiled).
     */
    explicit LeastSquaresPredictor(double alpha)
        : mAlpha(alpha)
        , mRw(alpha)
        , mRx(alpha)
        , mRy(alpha)
        , mRz(alpha)
    {}

    // Defined out of line.
    void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override;
    Pose3f predict(int64_t atNs) const override;
    void reset() override;

    // The decay factor is part of the name so instances with different alpha can be
    // told apart.
    std::string name() const override {
        return "LEAST_SQUARES(" + std::to_string(mAlpha) + ")";
    }
    std::string toString(size_t index) const override;

private:
    const double mAlpha;   // decay factor shared by all four fits
    int64_t mLastAtNs{};
    Pose3f mLastPose;
    static constexpr double kDefaultAlphaEstimator = 0.2;
    // Presumably the number of add() calls required before predict() uses the fit —
    // consumed by the out-of-line definitions; confirm there.
    static constexpr size_t kMinimumSamplesForPrediction = 4;
    // Declared after mAlpha; the constructor initializer list follows this order.
    audio_utils::LinearLeastSquaresFit<double> mRw;
    audio_utils::LinearLeastSquaresFit<double> mRx;
    audio_utils::LinearLeastSquaresFit<double> mRy;
    audio_utils::LinearLeastSquaresFit<double> mRz;
};
/*
* PosePredictor predicts the pose given sensor input at a time in the future.
*
* It owns a fixed set of PredictorBase implementations (mPredictors) and uses the
* one returned by getCurrentPredictor(). The selection is driven by mSetType and
* mCurrentType. Predictions are also fed to PosePredictorVerifier instances and
* VectorRecorders for comparison and logging; the details are in the .cpp.
*
* This class is not thread-safe.
*/
class PosePredictor {
public:
// Builds the predictors, lookahead list, verifiers and delimiter indices
// (definition not visible in this header).
PosePredictor();
// Adds the sample (pose, twist) taken at timestampNs and returns the predicted pose.
// NOTE(review): predictionDurationNs is presumably how far past timestampNs to
// predict, in ns — confirm in the .cpp.
Pose3f predict(int64_t timestampNs, const Pose3f& pose, const Twist3f& twist,
float predictionDurationNs);
// Sets the externally requested predictor type (may be AUTO; see mSetType below).
void setPosePredictorType(PosePredictorType type);
// convert predictions to a printable string
// (index is presumably the indentation width, matching PredictorBase::toString).
std::string toString(size_t index) const;
private:
// 300 ms. Presumably predictors are reset when consecutive samples are farther apart
// than this — confirm in predict().
static constexpr int64_t kMaximumSampleIntervalBeforeResetNs =
300'000'000;
// Predictors (fixed at construction).
const std::vector<std::shared_ptr<PredictorBase>> mPredictors;
// Future lookahead times to evaluate, in milliseconds.
const std::vector<int> mLookaheadMs;
// Verifiers: one per compared lookahead. Presumably one per (predictor, lookahead)
// pair; the sizing happens in the constructor.
std::vector<PosePredictorVerifier> mVerifiers;
// Delimiter indices passed to the recorders below; presumably they mark group
// boundaries inside each recorded vector.
const std::vector<size_t> mDelimiterIdx;
// Recorders
// NOTE: declaration order matters. The default member initializers below read
// mVerifiers and mDelimiterIdx, so those members must stay declared (and therefore
// initialized) before the recorders.
// Short-term recorder (1 s interval argument, up to 10 log lines).
media::VectorRecorder mPredictionRecorder{
std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */,
mDelimiterIdx};
// Long-term recorder (1 min interval argument, up to 10 log lines).
media::VectorRecorder mPredictionDurableRecorder{
std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */,
mDelimiterIdx};
// Status
// SetType is the externally set predictor type. It may include AUTO.
PosePredictorType mSetType = PosePredictorType::LEAST_SQUARES;
// CurrentType is the actual predictor type used by this class.
// It does not include AUTO because that metatype means the class
// chooses the best predictor type based on sensor statistics.
PosePredictorType mCurrentType = PosePredictorType::LEAST_SQUARES;
// Presumably the number of predictor resets performed (see
// kMaximumSampleIntervalBeforeResetNs).
int64_t mResets{};
// Presumably the timestamp of the previous predict() sample.
int64_t mLastTimestampNs{};
// Returns current predictor
std::shared_ptr<PredictorBase> getCurrentPredictor() const;
};
} // namespace android::media