Skip to content

Commit

Permalink
Added new pose estimation algorithms for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mpulte committed Feb 5, 2024
1 parent 869e984 commit 4256ba3
Show file tree
Hide file tree
Showing 10 changed files with 783 additions and 85 deletions.
19 changes: 19 additions & 0 deletions src/main/java/com/team1701/lib/drivers/gyros/GyroIOSim.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
public class GyroIOSim implements GyroIO {
private Supplier<Rotation2d> mYawSupplier;
private boolean mYawSamplingEnabled;
private Rotation2d mYaw = GeometryUtil.kRotationIdentity;
private int mSamples = 0;

public GyroIOSim() {
mYawSupplier = () -> GeometryUtil.kRotationIdentity;
Expand All @@ -29,6 +31,18 @@ public void updateInputs(GyroInputs inputs) {
if (mYawSamplingEnabled) {
inputs.yawSamples = new Rotation2d[] {inputs.yaw};
}

if (mYawSamplingEnabled) {
var samples = mSamples;
inputs.yawSamples = new Rotation2d[samples];
var lerp = inputs.yaw.minus(mYaw).div(samples + 1);
for (int i = 0; i < samples; i++) {
inputs.yawSamples[i] = mYaw.plus(lerp.times(i + 1));
}
}

mYaw = inputs.yaw;
mSamples = 0;
}

@Override
Expand All @@ -37,6 +51,11 @@ public synchronized void enableYawSampling(SignalSamplingThread samplingThread)
throw new IllegalStateException("Yaw sampling already enabled");
}

samplingThread.addSignal(() -> {
mSamples++;
return 0.0; // We will interpolate in updateInputs
});

mYawSamplingEnabled = true;
}

Expand Down
38 changes: 31 additions & 7 deletions src/main/java/com/team1701/lib/drivers/motors/MotorIOSim.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public class MotorIOSim implements MotorIO {
private double mPositionRadians;
private boolean mPositionSamplingEnabled;
private boolean mVelocitySamplingEnabled;
private int mPositionSamples = 0;
private int mVelocitySamples = 0;

public MotorIOSim(DCMotor motor, double reduction, double jKgMetersSquared, double loopPeriodSeconds) {
mSim = new DCMotorSim(motor, 1.0 / reduction, jKgMetersSquared);
Expand All @@ -32,19 +34,31 @@ public void updateInputs(MotorInputs inputs) {

mSim.update(mLoopPeriodSeconds);

mVelocityRadiansPerSecond = mSim.getAngularVelocityRadPerSec();
mPositionRadians += mVelocityRadiansPerSecond * mLoopPeriodSeconds;

inputs.positionRadians = mPositionRadians;
inputs.velocityRadiansPerSecond = mVelocityRadiansPerSecond;
inputs.velocityRadiansPerSecond = mSim.getAngularVelocityRadPerSec();
inputs.positionRadians = mPositionRadians + inputs.velocityRadiansPerSecond * mLoopPeriodSeconds;

if (mPositionSamplingEnabled) {
inputs.positionRadiansSamples = new double[] {mPositionRadians};
var samples = mPositionSamples;
inputs.positionRadiansSamples = new double[samples];
var lerp = (inputs.positionRadians - mPositionRadians) / (samples + 1);
for (int i = 0; i < samples; i++) {
inputs.positionRadiansSamples[i] = mPositionRadians + lerp * (i + 1);
}
}

if (mVelocitySamplingEnabled) {
inputs.velocityRadiansPerSecondSamples = new double[] {mVelocityRadiansPerSecond};
var samples = mVelocitySamples;
inputs.velocityRadiansPerSecondSamples = new double[samples];
var lerp = (inputs.velocityRadiansPerSecond - mVelocityRadiansPerSecond) / (samples + 1);
for (int i = 0; i < samples; i++) {
inputs.velocityRadiansPerSecondSamples[i] = mVelocityRadiansPerSecond + lerp * (i + 1);
}
}

mPositionRadians = inputs.positionRadians;
mVelocityRadiansPerSecond = inputs.velocityRadiansPerSecond;
mPositionSamples = 0;
mVelocitySamples = 0;
}

@Override
Expand Down Expand Up @@ -85,6 +99,11 @@ public synchronized void enablePositionSampling(SignalSamplingThread samplingThr
throw new IllegalStateException("Position sampling already enabled");
}

samplingThread.addSignal(() -> {
mPositionSamples++;
return 0.0; // We will interpolate in updateInputs
});

mPositionSamplingEnabled = true;
}

Expand All @@ -94,6 +113,11 @@ public synchronized void enableVelocitySampling(SignalSamplingThread samplingThr
throw new IllegalStateException("Velocity sampling already enabled");
}

samplingThread.addSignal(() -> {
mVelocitySamples++;
return 0.0; // We will interpolate in updateInputs
});

mVelocitySamplingEnabled = true;
}

Expand Down
209 changes: 209 additions & 0 deletions src/main/java/com/team1701/lib/estimation/PoseEstimator1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package com.team1701.lib.estimation;

import java.util.Arrays;
import java.util.Objects;
import java.util.stream.Stream;

import com.team1701.lib.swerve.ExtendedSwerveDriveKinematics;
import com.team1701.lib.util.GeometryUtil;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.geometry.Twist2d;
import edu.wpi.first.math.interpolation.Interpolatable;
import edu.wpi.first.math.interpolation.TimeInterpolatableBuffer;
import edu.wpi.first.math.kinematics.Odometry;
import edu.wpi.first.math.kinematics.SwerveDriveWheelPositions;
import edu.wpi.first.math.kinematics.SwerveModulePosition;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N3;

public class PoseEstimator1 {
private final ExtendedSwerveDriveKinematics mKinematics;
private final Odometry<SwerveDriveWheelPositions> mOdometry;
private final Matrix<N3, N1> mKalmanQ = new Matrix<>(Nat.N3(), Nat.N1());

private static final double kBufferDuration = 0.5;
private final TimeInterpolatableBuffer<InterpolationRecord> mPoseBuffer =
TimeInterpolatableBuffer.createBuffer(kBufferDuration);

private DriveMeasurement mLastDriveMeasurement;

public static record DriveMeasurement(
double timestampSeconds, Rotation2d gyroAngle, SwerveDriveWheelPositions modulePositions) {
public DriveMeasurement(double timestampSeconds, Rotation2d gyroAngle, SwerveModulePosition[] modulePositions) {
this(timestampSeconds, gyroAngle, new SwerveDriveWheelPositions(modulePositions));
}
}

public static record VisionMeasurement(double timestampSeconds, Pose2d pose, Matrix<N3, N1> stdDevs) {}

public PoseEstimator1(ExtendedSwerveDriveKinematics kinematics, Matrix<N3, N1> stateStdDevs) {
mKinematics = kinematics;

var positions = new SwerveModulePosition[kinematics.getNumModules()];
Arrays.fill(positions, new SwerveModulePosition());
mOdometry = new Odometry<>(
mKinematics,
GeometryUtil.kRotationIdentity,
new SwerveDriveWheelPositions(positions),
GeometryUtil.kPoseIdentity);

for (int i = 0; i < 3; ++i) {
mKalmanQ.set(i, 0, stateStdDevs.get(i, 0) * stateStdDevs.get(i, 0));
}
}

public void resetPose(Pose2d pose) {
mOdometry.resetPosition(mLastDriveMeasurement.gyroAngle, mLastDriveMeasurement.modulePositions, pose);
}

public Pose2d getEstimatedPose() {
return mOdometry.getPoseMeters();
}

public void addDriveMeasurement(DriveMeasurement measurement) {
var modulePositions = measurement.modulePositions.copy();
mOdometry.update(measurement.gyroAngle, modulePositions);
mPoseBuffer.addSample(
measurement.timestampSeconds,
new InterpolationRecord(mOdometry.getPoseMeters(), measurement.gyroAngle, modulePositions));
mLastDriveMeasurement =
new DriveMeasurement(measurement.timestampSeconds, measurement.gyroAngle, modulePositions);
}

public void addVisionMeasurements(VisionMeasurement[] visionMeasurements) {
var bufferTimespanThreshold = mPoseBuffer.getInternalBuffer().lastKey() - kBufferDuration;

Stream.of(visionMeasurements)
.filter(measurement -> measurement.timestampSeconds() > bufferTimespanThreshold)
.sorted((a, b) -> Double.compare(a.timestampSeconds(), b.timestampSeconds()))
.toArray(VisionMeasurement[]::new);

// (https://github.com/wpilibsuite/allwpilib/blob/main/wpimath/src/main/java/edu/wpi/first/math/estimator/)

for (var i = 0; i < visionMeasurements.length; i++) {
var measurement = visionMeasurements[i];

// Step 1: Get the pose odometry measured at the moment the vision measurement was made.
var sample = mPoseBuffer.getSample(measurement.timestampSeconds);
if (sample.isEmpty()) {
return;
}

// Step 2: Measure the twist between the odometry pose and the vision pose.
var twist = sample.get().poseMeters.log(measurement.pose);

// Step 3: We should not trust the twist entirely, so instead we scale this twist by a Kalman
// gain matrix representing how much we trust vision measurements compared to our current pose.
var kTimesTwist =
calculateKalmanGain(measurement.stdDevs).times(VecBuilder.fill(twist.dx, twist.dy, twist.dtheta));

// Step 4: Convert back to Twist2d.
var scaledTwist = new Twist2d(kTimesTwist.get(0, 0), kTimesTwist.get(1, 0), kTimesTwist.get(2, 0));

// Step 5: Reset Odometry to state at sample with vision adjustment.
mOdometry.resetPosition(
sample.get().gyroAngle,
sample.get().wheelPositions,
sample.get().poseMeters.exp(scaledTwist));

// Step 6: Record the current pose to allow multiple measurements from the same timestamp
mPoseBuffer.addSample(
measurement.timestampSeconds,
new InterpolationRecord(getEstimatedPose(), sample.get().gyroAngle, sample.get().wheelPositions));

// Step 7: Replay odometry inputs to update the pose buffer and correct odometry.
var entries = mPoseBuffer
.getInternalBuffer()
.tailMap(measurement.timestampSeconds)
.entrySet();
var maxTimestamp =
i + 1 < visionMeasurements.length ? visionMeasurements[i + 1].timestampSeconds() : Double.MAX_VALUE;
for (var entry : entries) {
addDriveMeasurement(new DriveMeasurement(
entry.getKey(), entry.getValue().gyroAngle, entry.getValue().wheelPositions));

// Need to update one entry past next vision measurement to allow for interpolation
if (entry.getKey() > maxTimestamp) {
break;
}
}
}
}

private Matrix<N3, N3> calculateKalmanGain(Matrix<N3, N1> visionMeasurementStdDevs) {
// (https://github.com/wpilibsuite/allwpilib/blob/main/wpimath/src/main/java/edu/wpi/first/math/estimator/)
var r = new double[3];
for (int i = 0; i < 3; ++i) {
r[i] = visionMeasurementStdDevs.get(i, 0) * visionMeasurementStdDevs.get(i, 0);
}

var visionK = new Matrix<>(Nat.N3(), Nat.N3());

for (int row = 0; row < 3; ++row) {
if (mKalmanQ.get(row, 0) == 0.0) {
visionK.set(row, row, 0.0);
} else {
visionK.set(
row,
row,
mKalmanQ.get(row, 0) / (mKalmanQ.get(row, 0) + Math.sqrt(mKalmanQ.get(row, 0) * r[row])));
}
}

return visionK;
}

// TODO: Consider using Twist2d instead of gyro/wheels
private class InterpolationRecord implements Interpolatable<InterpolationRecord> {
private final Pose2d poseMeters;
private final Rotation2d gyroAngle;
private final SwerveDriveWheelPositions wheelPositions;

private InterpolationRecord(Pose2d poseMeters, Rotation2d gyro, SwerveDriveWheelPositions wheelPositions) {
this.poseMeters = poseMeters;
this.gyroAngle = gyro;
this.wheelPositions = wheelPositions;
}

@Override
public InterpolationRecord interpolate(InterpolationRecord endValue, double t) {
if (t < 0) {
return this;
} else if (t >= 1) {
return endValue;
} else {
var wheelLerp = wheelPositions.interpolate(endValue.wheelPositions, t);
var gyroLerp = gyroAngle.interpolate(endValue.gyroAngle, t);

// Create a twist to represent the change based on the interpolated sensor inputs.
Twist2d twist = mKinematics.toTwist2d(wheelPositions, wheelLerp);
twist.dtheta = gyroLerp.minus(gyroAngle).getRadians();

return new InterpolationRecord(poseMeters.exp(twist), gyroLerp, wheelLerp);
}
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (!(obj instanceof PoseEstimator1.InterpolationRecord)) {
return false;
}
var record = (InterpolationRecord) obj;
return Objects.equals(gyroAngle, record.gyroAngle)
&& Objects.equals(wheelPositions, record.wheelPositions)
&& Objects.equals(poseMeters, record.poseMeters);
}

@Override
public int hashCode() {
return Objects.hash(gyroAngle, wheelPositions, poseMeters);
}
}
}
Loading

0 comments on commit 4256ba3

Please sign in to comment.