From 37d3ca28492c29ea2e2a13ce1f95c43d9207a31e Mon Sep 17 00:00:00 2001 From: Peter Nied Date: Thu, 11 Jan 2024 11:56:23 -0600 Subject: [PATCH] HeapBasedRateTracker uses time provider to allow simluating of time in unit tests (#3934) Signed-off-by: Peter Nied --- .../ratetracking/HeapBasedRateTracker.java | 10 ++++++- .../limiting/HeapBasedRateTrackerTest.java | 27 ++++++++++--------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/security/util/ratetracking/HeapBasedRateTracker.java b/src/main/java/org/opensearch/security/util/ratetracking/HeapBasedRateTracker.java index 40b1f622d0..46aa577254 100644 --- a/src/main/java/org/opensearch/security/util/ratetracking/HeapBasedRateTracker.java +++ b/src/main/java/org/opensearch/security/util/ratetracking/HeapBasedRateTracker.java @@ -18,8 +18,10 @@ package org.opensearch.security.util.ratetracking; import java.util.Arrays; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.function.LongSupplier; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; @@ -33,16 +35,22 @@ public class HeapBasedRateTracker implements RateTracker cache; + private final LongSupplier timeProvider; private final long timeWindowMs; private final int maxTimeOffsets; public HeapBasedRateTracker(long timeWindowMs, int allowedTries, int maxEntries) { + this(timeWindowMs, allowedTries, maxEntries, null); + } + + public HeapBasedRateTracker(long timeWindowMs, int allowedTries, int maxEntries, LongSupplier timeProvider) { if (allowedTries < 2) { throw new IllegalArgumentException("allowedTries must be >= 2"); } this.timeWindowMs = timeWindowMs; this.maxTimeOffsets = allowedTries > 2 ? allowedTries - 2 : 0; + this.timeProvider = Optional.ofNullable(timeProvider).orElse(System::currentTimeMillis); this.cache = CacheBuilder.newBuilder() .expireAfterAccess(this.timeWindowMs, TimeUnit.MILLISECONDS) .maximumSize(maxEntries) @@ -89,7 +97,7 @@ private class ClientRecord { private short timeOffsetEnd = -1; synchronized boolean track() { - long timestamp = System.currentTimeMillis(); + long timestamp = timeProvider.getAsLong(); if (this.startTime == -1 || timestamp - getMostRecent() >= timeWindowMs) { this.startTime = timestamp; diff --git a/src/test/java/org/opensearch/security/auth/limiting/HeapBasedRateTrackerTest.java b/src/test/java/org/opensearch/security/auth/limiting/HeapBasedRateTrackerTest.java index c92c328564..aaae27e8c3 100644 --- a/src/test/java/org/opensearch/security/auth/limiting/HeapBasedRateTrackerTest.java +++ b/src/test/java/org/opensearch/security/auth/limiting/HeapBasedRateTrackerTest.java @@ -17,7 +17,9 @@ package org.opensearch.security.auth.limiting; -import org.junit.Ignore; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongSupplier; + import org.junit.Test; import org.opensearch.security.util.ratetracking.HeapBasedRateTracker; @@ -27,9 +29,12 @@ public class HeapBasedRateTrackerTest { + private final AtomicLong currentTime = new AtomicLong(1); + private LongSupplier timeProvider = () -> currentTime.getAndAdd(1); + @Test public void simpleTest() throws Exception { - HeapBasedRateTracker tracker = new HeapBasedRateTracker<>(100, 5, 100_000); + HeapBasedRateTracker tracker = new HeapBasedRateTracker<>(100, 5, 100_000, timeProvider); assertFalse(tracker.track("a")); assertFalse(tracker.track("a")); @@ -40,9 +45,8 @@ public void simpleTest() throws Exception { } @Test - @Ignore // https://github.com/opensearch-project/security/issues/2193 public void expiryTest() throws Exception { - HeapBasedRateTracker tracker = new HeapBasedRateTracker<>(100, 5, 100_000); + HeapBasedRateTracker tracker = new HeapBasedRateTracker<>(100, 5, 100_000, timeProvider); assertFalse(tracker.track("a")); assertFalse(tracker.track("a")); @@ -58,20 +62,20 @@ public void expiryTest() throws Exception { assertFalse(tracker.track("c")); - Thread.sleep(50); + currentTime.addAndGet(50); assertFalse(tracker.track("c")); assertFalse(tracker.track("c")); assertFalse(tracker.track("c")); - Thread.sleep(55); + currentTime.addAndGet(55); assertFalse(tracker.track("c")); assertTrue(tracker.track("c")); assertFalse(tracker.track("a")); - Thread.sleep(55); + currentTime.addAndGet(55); assertFalse(tracker.track("c")); assertFalse(tracker.track("c")); assertTrue(tracker.track("c")); @@ -79,21 +83,20 @@ public void expiryTest() throws Exception { } @Test - @Ignore // https://github.com/opensearch-project/security/issues/2193 public void maxTwoTriesTest() throws Exception { - HeapBasedRateTracker tracker = new HeapBasedRateTracker<>(100, 2, 100_000); + HeapBasedRateTracker tracker = new HeapBasedRateTracker<>(100, 2, 100_000, timeProvider); assertFalse(tracker.track("a")); assertTrue(tracker.track("a")); assertFalse(tracker.track("b")); - Thread.sleep(50); + currentTime.addAndGet(50); assertTrue(tracker.track("b")); - Thread.sleep(55); + currentTime.addAndGet(55); assertTrue(tracker.track("b")); - Thread.sleep(105); + currentTime.addAndGet(105); assertFalse(tracker.track("b")); assertTrue(tracker.track("b"));