diff --git a/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/ObjectClient.java b/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/ObjectClient.java index ec268be4..e5e5d440 100644 --- a/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/ObjectClient.java +++ b/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/ObjectClient.java @@ -36,4 +36,13 @@ public interface ObjectClient extends Closeable { * @return ResponseInputStream */ CompletableFuture getObject(GetRequest getRequest); + + /** + * Make a getObject request to the object store. + * + * @param getRequest The GET request to be sent + * @param streamContext audit headers to be attached in the request header + * @return ResponseInputStream + */ + CompletableFuture getObject(GetRequest getRequest, StreamContext streamContext); } diff --git a/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/StreamContext.java b/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/StreamContext.java new file mode 100644 index 00000000..fdafb586 --- /dev/null +++ b/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/StreamContext.java @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package software.amazon.s3.analyticsaccelerator.request; +/** + * The StreamContext interface provides methods for modifying and building referrer header which + * will then be attached to subsequent HTTP requests. + */ +public interface StreamContext { + + /** + * Modifies and builds the referrer header string for a given request context. + * + *

Implementation Note: To ensure thread safety, implementations should create and modify a + * copy of the internal state rather than modifying the original object directly. This is crucial + * as multiple threads may be accessing the same StreamContext instance concurrently. + * + *

Example implementation: + * + *

+   * public class S3AStreamContext implements StreamContext {
+   *     private final HttpReferrerAuditHeader referrer;
+   *
+   *     public S3AStreamContext(HttpReferrerAuditHeader referrer) {
+   *         this.referrer = referrer;
+   *     }
+   *
+   *     @Override
+   *     public String modifyAndBuildReferrerHeader(GetRequest getRequestContext) {
+   *         // Create a copy to ensure thread safety
+   *         HttpReferrerAuditHeader copyReferrer = new HttpReferrerAuditHeader(this.referrer);
+   *         copyReferrer.set(AuditConstants.PARAM_RANGE, getRequestContext.getRange().toHttpString());
+   *         return copyReferrer.buildHttpReferrer();
+   *     }
+   * }
+   * 
+ * + * @param getRequestContext the request context for building the referrer header + * @return the modified and built referrer header as a String + */ + public String modifyAndBuildReferrerHeader(GetRequest getRequestContext); +} diff --git a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactory.java b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactory.java index 5f9043ca..f9029e64 100644 --- a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactory.java +++ b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactory.java @@ -29,6 +29,7 @@ import software.amazon.s3.analyticsaccelerator.io.physical.impl.PhysicalIOImpl; import software.amazon.s3.analyticsaccelerator.request.ObjectClient; import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.ObjectFormatSelector; import software.amazon.s3.analyticsaccelerator.util.S3URI; @@ -106,12 +107,28 @@ public S3SeekableInputStream createStream(@NonNull S3URI s3URI, long contentLeng return new S3SeekableInputStream(s3URI, createLogicalIO(s3URI), telemetry); } + /** + * Create an instance of S3SeekableInputStream with streamContext. + * + * @param s3URI the object's S3 URI + * @param streamContext contains audit headers to be attached in request header + * @return An instance of the input stream. + */ + public S3SeekableInputStream createStream(@NonNull S3URI s3URI, StreamContext streamContext) { + return new S3SeekableInputStream(s3URI, createLogicalIO(s3URI, streamContext), telemetry); + } + LogicalIO createLogicalIO(S3URI s3URI) { + return createLogicalIO(s3URI, null); + } + + LogicalIO createLogicalIO(S3URI s3URI, StreamContext streamContext) { switch (objectFormatSelector.getObjectFormat(s3URI)) { case PARQUET: return new ParquetLogicalIOImpl( s3URI, - new PhysicalIOImpl(s3URI, objectMetadataStore, objectBlobStore, telemetry), + new PhysicalIOImpl( + s3URI, objectMetadataStore, objectBlobStore, telemetry, streamContext), telemetry, configuration.getLogicalIOConfiguration(), parquetColumnPrefetchStore); @@ -119,7 +136,8 @@ LogicalIO createLogicalIO(S3URI s3URI) { default: return new DefaultLogicalIOImpl( s3URI, - new PhysicalIOImpl(s3URI, objectMetadataStore, objectBlobStore, telemetry), + new PhysicalIOImpl( + s3URI, objectMetadataStore, objectBlobStore, telemetry, streamContext), telemetry); } } diff --git a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStore.java b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStore.java index 0fdbabca..04564c08 100644 --- a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStore.java +++ b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStore.java @@ -24,6 +24,7 @@ import software.amazon.s3.analyticsaccelerator.common.telemetry.Telemetry; import software.amazon.s3.analyticsaccelerator.io.physical.PhysicalIOConfiguration; import software.amazon.s3.analyticsaccelerator.request.ObjectClient; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.S3URI; /** A BlobStore is a container for Blobs and functions as a data cache. */ @@ -69,16 +70,18 @@ protected boolean removeEldestEntry(final Map.Entry eldest) { * Opens a new blob if one does not exist or returns the handle to one that exists already. * * @param s3URI the S3 URI of the object + * @param streamContext contains audit headers to be attached in the request header * @return the blob representing the object from the BlobStore */ - public Blob get(S3URI s3URI) { + public Blob get(S3URI s3URI, StreamContext streamContext) { return blobMap.computeIfAbsent( s3URI, uri -> new Blob( uri, metadataStore, - new BlockManager(uri, objectClient, metadataStore, telemetry, configuration), + new BlockManager( + uri, objectClient, metadataStore, telemetry, configuration, streamContext), telemetry)); } diff --git a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/Block.java b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/Block.java index a5da0d18..022b5319 100644 --- a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/Block.java +++ b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/Block.java @@ -28,6 +28,7 @@ import software.amazon.s3.analyticsaccelerator.request.Range; import software.amazon.s3.analyticsaccelerator.request.ReadMode; import software.amazon.s3.analyticsaccelerator.request.Referrer; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.S3URI; import software.amazon.s3.analyticsaccelerator.util.StreamAttributes; import software.amazon.s3.analyticsaccelerator.util.StreamUtils; @@ -51,7 +52,7 @@ public class Block implements Closeable { private static final String OPERATION_BLOCK_GET_JOIN = "block.get.join"; /** - * Constructs a Block. data. + * Constructs a Block data. * * @param s3URI the S3 URI of the object * @param objectClient the object client to use to interact with the object store @@ -69,6 +70,32 @@ public Block( long end, long generation, @NonNull ReadMode readMode) { + + this(s3URI, objectClient, telemetry, start, end, generation, readMode, null); + } + + /** + * Constructs a Block data. + * + * @param s3URI the S3 URI of the object + * @param objectClient the object client to use to interact with the object store + * @param telemetry an instance of {@link Telemetry} to use + * @param start start of the block + * @param end end of the block + * @param generation generation of the block in a sequential read pattern (should be 0 by default) + * @param readMode read mode describing whether this is a sync or async fetch + * @param streamContext contains audit headers to be attached in the request header + */ + public Block( + @NonNull S3URI s3URI, + @NonNull ObjectClient objectClient, + @NonNull Telemetry telemetry, + long start, + long end, + long generation, + @NonNull ReadMode readMode, + StreamContext streamContext) { + Preconditions.checkArgument( 0 <= generation, "`generation` must be non-negative; was: %s", generation); Preconditions.checkArgument(0 <= start, "`start` must be non-negative; was: %s", start); @@ -97,7 +124,8 @@ public Block( .s3Uri(this.s3URI) .range(this.range) .referrer(new Referrer(range.toHttpString(), readMode)) - .build())); + .build(), + streamContext)); this.data = this.source.thenApply(StreamUtils::toByteArray); } diff --git a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManager.java b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManager.java index f59bd42c..0824c7d2 100644 --- a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManager.java +++ b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManager.java @@ -29,6 +29,7 @@ import software.amazon.s3.analyticsaccelerator.request.ObjectClient; import software.amazon.s3.analyticsaccelerator.request.Range; import software.amazon.s3.analyticsaccelerator.request.ReadMode; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.S3URI; import software.amazon.s3.analyticsaccelerator.util.StreamAttributes; @@ -44,6 +45,7 @@ public class BlockManager implements Closeable { private final IOPlanner ioPlanner; private final PhysicalIOConfiguration configuration; private final RangeOptimiser rangeOptimiser; + private StreamContext streamContext; private static final String OPERATION_MAKE_RANGE_AVAILABLE = "block.manager.make.range.available"; @@ -62,6 +64,26 @@ public BlockManager( @NonNull MetadataStore metadataStore, @NonNull Telemetry telemetry, @NonNull PhysicalIOConfiguration configuration) { + this(s3URI, objectClient, metadataStore, telemetry, configuration, null); + } + + /** + * Constructs a new BlockManager. + * + * @param s3URI the S3 URI of the object + * @param objectClient object client capable of interacting with the underlying object store + * @param telemetry an instance of {@link Telemetry} to use + * @param metadataStore the metadata cache + * @param configuration the physicalIO configuration + * @param streamContext contains audit headers to be attached in the request header + */ + public BlockManager( + @NonNull S3URI s3URI, + @NonNull ObjectClient objectClient, + @NonNull MetadataStore metadataStore, + @NonNull Telemetry telemetry, + @NonNull PhysicalIOConfiguration configuration, + StreamContext streamContext) { this.s3URI = s3URI; this.objectClient = objectClient; this.metadataStore = metadataStore; @@ -72,6 +94,7 @@ public BlockManager( this.sequentialReadProgression = new SequentialReadProgression(configuration); this.ioPlanner = new IOPlanner(blockStore); this.rangeOptimiser = new RangeOptimiser(configuration); + this.streamContext = streamContext; } /** @@ -178,7 +201,8 @@ public synchronized void makeRangeAvailable(long pos, long len, ReadMode readMod r.getStart(), r.getEnd(), generation, - readMode); + readMode, + streamContext); blockStore.add(block); }); }); diff --git a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImpl.java b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImpl.java index a44d1a04..38573692 100644 --- a/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImpl.java +++ b/input-stream/src/main/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImpl.java @@ -26,6 +26,7 @@ import software.amazon.s3.analyticsaccelerator.io.physical.plan.IOPlan; import software.amazon.s3.analyticsaccelerator.io.physical.plan.IOPlanExecution; import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.S3URI; import software.amazon.s3.analyticsaccelerator.util.StreamAttributes; @@ -35,6 +36,7 @@ public class PhysicalIOImpl implements PhysicalIO { private final MetadataStore metadataStore; private final BlobStore blobStore; private final Telemetry telemetry; + private final StreamContext streamContext; private final long physicalIOBirth = System.nanoTime(); @@ -56,10 +58,29 @@ public PhysicalIOImpl( @NonNull MetadataStore metadataStore, @NonNull BlobStore blobStore, @NonNull Telemetry telemetry) { + this(s3URI, metadataStore, blobStore, telemetry, null); + } + + /** + * Construct a new instance of PhysicalIOV2. + * + * @param s3URI the S3 URI of the object + * @param metadataStore a metadata cache + * @param blobStore a data cache + * @param telemetry The {@link Telemetry} to use to report measurements. + * @param streamContext contains audit headers to be attached in the request header + */ + public PhysicalIOImpl( + @NonNull S3URI s3URI, + @NonNull MetadataStore metadataStore, + @NonNull BlobStore blobStore, + @NonNull Telemetry telemetry, + StreamContext streamContext) { this.s3URI = s3URI; this.metadataStore = metadataStore; this.blobStore = blobStore; this.telemetry = telemetry; + this.streamContext = streamContext; } /** @@ -94,7 +115,7 @@ public int read(long pos) throws IOException { StreamAttributes.physicalIORelativeTimestamp( System.nanoTime() - physicalIOBirth)) .build(), - () -> blobStore.get(s3URI).read(pos)); + () -> blobStore.get(s3URI, streamContext).read(pos)); } /** @@ -124,7 +145,7 @@ public int read(byte[] buf, int off, int len, long pos) throws IOException { StreamAttributes.physicalIORelativeTimestamp( System.nanoTime() - physicalIOBirth)) .build(), - () -> blobStore.get(s3URI).read(buf, off, len, pos)); + () -> blobStore.get(s3URI, streamContext).read(buf, off, len, pos)); } /** @@ -151,7 +172,7 @@ public int readTail(byte[] buf, int off, int len) throws IOException { StreamAttributes.physicalIORelativeTimestamp( System.nanoTime() - physicalIOBirth)) .build(), - () -> blobStore.get(s3URI).read(buf, off, len, contentLength - len)); + () -> blobStore.get(s3URI, streamContext).read(buf, off, len, contentLength - len)); } /** @@ -172,7 +193,7 @@ public IOPlanExecution execute(IOPlan ioPlan) { StreamAttributes.physicalIORelativeTimestamp( System.nanoTime() - physicalIOBirth)) .build(), - () -> blobStore.get(s3URI).execute(ioPlan)); + () -> blobStore.get(s3URI, streamContext).execute(ioPlan)); } private long contentLength() { diff --git a/input-stream/src/referenceTest/java/software/amazon/s3/analyticsaccelerator/property/InMemoryS3SeekableInputStream.java b/input-stream/src/referenceTest/java/software/amazon/s3/analyticsaccelerator/property/InMemoryS3SeekableInputStream.java index 799013df..d071a6ad 100644 --- a/input-stream/src/referenceTest/java/software/amazon/s3/analyticsaccelerator/property/InMemoryS3SeekableInputStream.java +++ b/input-stream/src/referenceTest/java/software/amazon/s3/analyticsaccelerator/property/InMemoryS3SeekableInputStream.java @@ -24,11 +24,7 @@ import software.amazon.s3.analyticsaccelerator.S3SeekableInputStreamConfiguration; import software.amazon.s3.analyticsaccelerator.S3SeekableInputStreamFactory; import software.amazon.s3.analyticsaccelerator.SeekableInputStream; -import software.amazon.s3.analyticsaccelerator.request.GetRequest; -import software.amazon.s3.analyticsaccelerator.request.HeadRequest; -import software.amazon.s3.analyticsaccelerator.request.ObjectClient; -import software.amazon.s3.analyticsaccelerator.request.ObjectContent; -import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; +import software.amazon.s3.analyticsaccelerator.request.*; import software.amazon.s3.analyticsaccelerator.util.S3URI; public class InMemoryS3SeekableInputStream extends SeekableInputStream { @@ -73,6 +69,12 @@ public CompletableFuture headObject(HeadRequest headRequest) { @Override public CompletableFuture getObject(GetRequest getRequest) { + return getObject(getRequest, null); + } + + @Override + public CompletableFuture getObject( + GetRequest getRequest, StreamContext streamContext) { int start = 0; int end = size - 1; diff --git a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactoryTest.java b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactoryTest.java index 44a13bc5..2fe70eeb 100644 --- a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactoryTest.java +++ b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/S3SeekableInputStreamFactoryTest.java @@ -24,6 +24,7 @@ import software.amazon.s3.analyticsaccelerator.io.logical.impl.DefaultLogicalIOImpl; import software.amazon.s3.analyticsaccelerator.io.logical.impl.ParquetLogicalIOImpl; import software.amazon.s3.analyticsaccelerator.request.ObjectClient; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.S3URI; @SuppressFBWarnings( @@ -68,6 +69,11 @@ void testCreateDefaultStream() { S3SeekableInputStream inputStream = s3SeekableInputStreamFactory.createStream(S3URI.of("bucket", "key")); assertNotNull(inputStream); + + inputStream = + s3SeekableInputStreamFactory.createStream( + S3URI.of("bucket", "key"), mock(StreamContext.class)); + assertNotNull(inputStream); } @Test @@ -96,6 +102,11 @@ void testCreateIndependentStream() { S3SeekableInputStream inputStream = s3SeekableInputStreamFactory.createStream(S3URI.of("bucket", "key")); assertNotNull(inputStream); + + inputStream = + s3SeekableInputStreamFactory.createStream( + S3URI.of("bucket", "key"), mock(StreamContext.class)); + assertNotNull(inputStream); } @Test @@ -108,6 +119,12 @@ void testCreateStreamThrowsOnNullArgument() { () -> { s3SeekableInputStreamFactory.createStream(null); }); + + assertThrows( + NullPointerException.class, + () -> { + s3SeekableInputStreamFactory.createStream(null, mock(StreamContext.class)); + }); } @Test @@ -121,17 +138,21 @@ void testCreateLogicalIO() { new S3SeekableInputStreamFactory(mock(ObjectClient.class), configuration); assertTrue( - s3SeekableInputStreamFactory.createLogicalIO(S3URI.of("bucket", "key.parquet")) + s3SeekableInputStreamFactory.createLogicalIO( + S3URI.of("bucket", "key.parquet"), mock(StreamContext.class)) instanceof ParquetLogicalIOImpl); assertTrue( - s3SeekableInputStreamFactory.createLogicalIO(S3URI.of("bucket", "key.par")) + s3SeekableInputStreamFactory.createLogicalIO( + S3URI.of("bucket", "key.par"), mock(StreamContext.class)) instanceof ParquetLogicalIOImpl); assertTrue( - s3SeekableInputStreamFactory.createLogicalIO(S3URI.of("bucket", "key.java")) + s3SeekableInputStreamFactory.createLogicalIO( + S3URI.of("bucket", "key.java"), mock(StreamContext.class)) instanceof DefaultLogicalIOImpl); assertTrue( - s3SeekableInputStreamFactory.createLogicalIO(S3URI.of("bucket", "key.txt")) + s3SeekableInputStreamFactory.createLogicalIO( + S3URI.of("bucket", "key.txt"), mock(StreamContext.class)) instanceof DefaultLogicalIOImpl); } diff --git a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/logical/impl/DefaultLogicalIOImplTest.java b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/logical/impl/DefaultLogicalIOImplTest.java index 468e7b80..3e1e1432 100644 --- a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/logical/impl/DefaultLogicalIOImplTest.java +++ b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/logical/impl/DefaultLogicalIOImplTest.java @@ -74,7 +74,6 @@ void testRead() throws IOException { PhysicalIO physicalIO = mock(PhysicalIO.class); DefaultLogicalIOImpl logicalIO = new DefaultLogicalIOImpl(TEST_URI, physicalIO, mock(Telemetry.class)); - logicalIO.read(5); verify(physicalIO).read(5); } @@ -94,7 +93,6 @@ void testReadTail() throws IOException { PhysicalIO physicalIO = mock(PhysicalIO.class); when(physicalIO.metadata()).thenReturn(ObjectMetadata.builder().contentLength(123).build()); DefaultLogicalIOImpl logicalIO = new DefaultLogicalIOImpl(TEST_URI, physicalIO, Telemetry.NOOP); - byte[] buffer = new byte[5]; logicalIO.readTail(buffer, 0, 5); verify(physicalIO).readTail(buffer, 0, 5); diff --git a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStoreTest.java b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStoreTest.java index e42dcb57..d776dd3b 100644 --- a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStoreTest.java +++ b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlobStoreTest.java @@ -28,6 +28,7 @@ import software.amazon.s3.analyticsaccelerator.io.physical.PhysicalIOConfiguration; import software.amazon.s3.analyticsaccelerator.request.ObjectClient; import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.FakeObjectClient; import software.amazon.s3.analyticsaccelerator.util.S3URI; @@ -81,7 +82,7 @@ public void testGetReturnsReadableBlob() { metadataStore, objectClient, TestTelemetry.DEFAULT, PhysicalIOConfiguration.DEFAULT); // When: a Blob is asked for - Blob blob = blobStore.get(S3URI.of("test", "test")); + Blob blob = blobStore.get(S3URI.of("test", "test"), mock(StreamContext.class)); // Then: byte[] b = new byte[TEST_DATA.length()]; diff --git a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManagerTest.java b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManagerTest.java index 4e00283e..b62bd0f9 100644 --- a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManagerTest.java +++ b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/data/BlockManagerTest.java @@ -32,11 +32,7 @@ import software.amazon.s3.analyticsaccelerator.TestTelemetry; import software.amazon.s3.analyticsaccelerator.common.telemetry.Telemetry; import software.amazon.s3.analyticsaccelerator.io.physical.PhysicalIOConfiguration; -import software.amazon.s3.analyticsaccelerator.request.GetRequest; -import software.amazon.s3.analyticsaccelerator.request.ObjectClient; -import software.amazon.s3.analyticsaccelerator.request.ObjectContent; -import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; -import software.amazon.s3.analyticsaccelerator.request.ReadMode; +import software.amazon.s3.analyticsaccelerator.request.*; import software.amazon.s3.analyticsaccelerator.util.S3URI; @SuppressFBWarnings( @@ -128,7 +124,7 @@ void testMakePositionAvailableRespectsReadAhead() { // Then ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); - verify(objectClient).getObject(requestCaptor.capture()); + verify(objectClient).getObject(requestCaptor.capture(), any()); assertEquals(0, requestCaptor.getValue().getRange().getStart()); assertEquals( @@ -148,7 +144,7 @@ void testMakePositionAvailableRespectsLastObjectByte() { // Then ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); - verify(objectClient).getObject(requestCaptor.capture()); + verify(objectClient).getObject(requestCaptor.capture(), any()); assertEquals(0, requestCaptor.getValue().getRange().getStart()); assertEquals(objectSize - 1, requestCaptor.getValue().getRange().getEnd()); @@ -165,7 +161,7 @@ void testMakeRangeAvailableDoesNotOverread() { // When: requesting the byte at 64KB blockManager.makeRangeAvailable(64 * ONE_KB, 100, ReadMode.SYNC); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); - verify(objectClient, times(3)).getObject(requestCaptor.capture()); + verify(objectClient, times(3)).getObject(requestCaptor.capture(), any()); // Then: request size is a single byte as more is not needed GetRequest firstRequest = requestCaptor.getAllValues().get(0); @@ -219,7 +215,7 @@ private BlockManager getTestBlockManager(ObjectClient objectClient, int size) { private BlockManager getTestBlockManager( ObjectClient objectClient, int size, PhysicalIOConfiguration configuration) { S3URI testUri = S3URI.of("foo", "bar"); - when(objectClient.getObject(any())) + when(objectClient.getObject(any(), any())) .thenReturn( CompletableFuture.completedFuture( ObjectContent.builder().stream(new ByteArrayInputStream(new byte[size])).build())); diff --git a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImplTest.java b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImplTest.java index 79384949..0d42398a 100644 --- a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImplTest.java +++ b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/io/physical/impl/PhysicalIOImplTest.java @@ -16,20 +16,84 @@ package software.amazon.s3.analyticsaccelerator.io.physical.impl; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.*; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; import org.junit.jupiter.api.Test; import software.amazon.s3.analyticsaccelerator.TestTelemetry; import software.amazon.s3.analyticsaccelerator.io.physical.PhysicalIOConfiguration; import software.amazon.s3.analyticsaccelerator.io.physical.data.BlobStore; import software.amazon.s3.analyticsaccelerator.io.physical.data.MetadataStore; +import software.amazon.s3.analyticsaccelerator.request.StreamContext; import software.amazon.s3.analyticsaccelerator.util.FakeObjectClient; import software.amazon.s3.analyticsaccelerator.util.S3URI; +@SuppressFBWarnings( + value = "NP_NONNULL_PARAM_VIOLATION", + justification = "We mean to pass nulls to checks") public class PhysicalIOImplTest { private static final S3URI s3URI = S3URI.of("foo", "bar"); + @Test + void testConstructorThrowsOnNullArgument() { + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl( + s3URI, null, mock(BlobStore.class), TestTelemetry.DEFAULT, mock(StreamContext.class)); + }); + + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl( + s3URI, + mock(MetadataStore.class), + null, + TestTelemetry.DEFAULT, + mock(StreamContext.class)); + }); + + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl( + null, mock(MetadataStore.class), mock(BlobStore.class), TestTelemetry.DEFAULT); + }); + + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl(s3URI, mock(MetadataStore.class), mock(BlobStore.class), null); + }); + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl(s3URI, null, mock(BlobStore.class), TestTelemetry.DEFAULT); + }); + + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl(s3URI, mock(MetadataStore.class), null, TestTelemetry.DEFAULT); + }); + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl( + null, mock(MetadataStore.class), mock(BlobStore.class), TestTelemetry.DEFAULT); + }); + + assertThrows( + NullPointerException.class, + () -> { + new PhysicalIOImpl(s3URI, mock(MetadataStore.class), mock(BlobStore.class), null); + }); + } + @Test public void test__readSingleByte_isCorrect() throws IOException { // Given: physicalIOImplV2 @@ -73,4 +137,41 @@ public void test__regression_singleByteStream() throws IOException { // Then: returned data is correct assertEquals(120, physicalIOImplV2.read(0)); // a } + + @Test + void testReadWithBuffer() throws IOException { + final String TEST_DATA = "abcdef0123456789"; + FakeObjectClient fakeObjectClient = new FakeObjectClient(TEST_DATA); + MetadataStore metadataStore = + new MetadataStore(fakeObjectClient, TestTelemetry.DEFAULT, PhysicalIOConfiguration.DEFAULT); + BlobStore blobStore = + new BlobStore( + metadataStore, + fakeObjectClient, + TestTelemetry.DEFAULT, + PhysicalIOConfiguration.DEFAULT); + PhysicalIOImpl physicalIOImplV2 = + new PhysicalIOImpl(s3URI, metadataStore, blobStore, TestTelemetry.DEFAULT); + + byte[] buffer = new byte[5]; + assertEquals(5, physicalIOImplV2.read(buffer, 0, 5, 5)); + } + + @Test + void testReadTail() throws IOException { + final String TEST_DATA = "abcdef0123456789"; + FakeObjectClient fakeObjectClient = new FakeObjectClient(TEST_DATA); + MetadataStore metadataStore = + new MetadataStore(fakeObjectClient, TestTelemetry.DEFAULT, PhysicalIOConfiguration.DEFAULT); + BlobStore blobStore = + new BlobStore( + metadataStore, + fakeObjectClient, + TestTelemetry.DEFAULT, + PhysicalIOConfiguration.DEFAULT); + PhysicalIOImpl physicalIOImplV2 = + new PhysicalIOImpl(s3URI, metadataStore, blobStore, TestTelemetry.DEFAULT); + byte[] buffer = new byte[5]; + assertEquals(5, physicalIOImplV2.readTail(buffer, 0, 5)); + } } diff --git a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/util/FakeObjectClient.java b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/util/FakeObjectClient.java index 86f917aa..db0c51d2 100644 --- a/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/util/FakeObjectClient.java +++ b/input-stream/src/test/java/software/amazon/s3/analyticsaccelerator/util/FakeObjectClient.java @@ -23,12 +23,7 @@ import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.atomic.AtomicInteger; import lombok.Getter; -import software.amazon.s3.analyticsaccelerator.request.GetRequest; -import software.amazon.s3.analyticsaccelerator.request.HeadRequest; -import software.amazon.s3.analyticsaccelerator.request.ObjectClient; -import software.amazon.s3.analyticsaccelerator.request.ObjectContent; -import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; -import software.amazon.s3.analyticsaccelerator.request.Range; +import software.amazon.s3.analyticsaccelerator.request.*; public class FakeObjectClient implements ObjectClient { @@ -60,6 +55,12 @@ public CompletableFuture headObject(HeadRequest headRequest) { @Override public CompletableFuture getObject(GetRequest getRequest) { + return getObject(getRequest, null); + } + + @Override + public CompletableFuture getObject( + GetRequest getRequest, StreamContext streamContext) { getRequestCount.incrementAndGet(); requestedRanges.add(getRequest.getRange()); return CompletableFuture.completedFuture( diff --git a/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java b/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java index b22ae045..06a2614f 100644 --- a/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java +++ b/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java @@ -18,6 +18,8 @@ import java.util.concurrent.CompletableFuture; import lombok.Getter; import lombok.NonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -32,6 +34,7 @@ public class S3SdkObjectClient implements ObjectClient { private static final String HEADER_USER_AGENT = "User-Agent"; private static final String HEADER_REFERER = "Referer"; + private static final Logger LOG = LoggerFactory.getLogger(S3SdkObjectClient.class); @Getter @NonNull private final S3AsyncClient s3AsyncClient; @NonNull private final Telemetry telemetry; @@ -142,17 +145,40 @@ public CompletableFuture headObject(HeadRequest headRequest) { */ @Override public CompletableFuture getObject(GetRequest getRequest) { + return getObject(getRequest, null); + } + + /** + * Make a getObject request to the object store. + * + * @param getRequest The GET request to be sent + * @param streamContext audit headers to be attached in the request header + * @return ResponseInputStream + */ + @Override + public CompletableFuture getObject( + GetRequest getRequest, StreamContext streamContext) { + GetObjectRequest.Builder builder = GetObjectRequest.builder() .bucket(getRequest.getS3Uri().getBucket()) .key(getRequest.getS3Uri().getKey()); - String range = getRequest.getRange().toHttpString(); + final String range = getRequest.getRange().toHttpString(); builder.range(range); + final String referrerHeader; + if (streamContext != null) { + referrerHeader = streamContext.modifyAndBuildReferrerHeader(getRequest); + } else { + referrerHeader = getRequest.getReferrer().toString(); + } + + LOG.info("auditHeaders {}", referrerHeader); + builder.overrideConfiguration( AwsRequestOverrideConfiguration.builder() - .putHeader(HEADER_REFERER, getRequest.getReferrer().toString()) + .putHeader(HEADER_REFERER, referrerHeader) .putHeader(HEADER_USER_AGENT, this.userAgent.getUserAgent()) .build()); diff --git a/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java b/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java index 5cf92671..d8fb3ea1 100644 --- a/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java +++ b/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java @@ -27,6 +27,8 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import software.amazon.awssdk.core.ResponseInputStream; @@ -37,12 +39,7 @@ import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; -import software.amazon.s3.analyticsaccelerator.request.GetRequest; -import software.amazon.s3.analyticsaccelerator.request.HeadRequest; -import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; -import software.amazon.s3.analyticsaccelerator.request.Range; -import software.amazon.s3.analyticsaccelerator.request.ReadMode; -import software.amazon.s3.analyticsaccelerator.request.Referrer; +import software.amazon.s3.analyticsaccelerator.request.*; import software.amazon.s3.analyticsaccelerator.util.S3URI; @SuppressFBWarnings( @@ -50,6 +47,9 @@ justification = "We mean to pass nulls to checks. Also, closures cannot be made static in this case") public class S3SdkObjectClientTest { + + private static final String HEADER_REFERER = "Referer"; + @Test void testForNullsInConstructor() { try (S3AsyncClient client = mock(S3AsyncClient.class)) { @@ -167,6 +167,71 @@ void testGetObjectWithRange() { } } + @Test + void testGetObjectWithAuditHeaders() { + S3AsyncClient mockS3AsyncClient = createMockClient(); + + S3SdkObjectClient client = new S3SdkObjectClient(mockS3AsyncClient); + + StreamContext mockStreamContext = mock(StreamContext.class); + when(mockStreamContext.modifyAndBuildReferrerHeader(any())).thenReturn("audit-referrer-value"); + + GetRequest getRequest = + GetRequest.builder() + .s3Uri(S3URI.of("bucket", "key")) + .range(new Range(0, 20)) + .referrer(new Referrer("bytes=0-20", ReadMode.SYNC)) + .build(); + + client.getObject(getRequest, mockStreamContext); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(GetObjectRequest.class); + verify(mockS3AsyncClient) + .getObject( + requestCaptor.capture(), + ArgumentMatchers + .>> + any()); + + GetObjectRequest capturedRequest = requestCaptor.getValue(); + assertEquals( + "audit-referrer-value", + capturedRequest.overrideConfiguration().get().headers().get(HEADER_REFERER).get(0)); + } + + @Test + void testGetObjectWithoutAuditHeaders() { + S3AsyncClient mockS3AsyncClient = createMockClient(); + + S3SdkObjectClient client = new S3SdkObjectClient(mockS3AsyncClient); + + GetRequest getRequest = + GetRequest.builder() + .s3Uri(S3URI.of("bucket", "key")) + .range(new Range(0, 20)) + .referrer(new Referrer("original-referrer", ReadMode.SYNC)) + .build(); + + client.getObject(getRequest, null); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(GetObjectRequest.class); + verify(mockS3AsyncClient) + .getObject( + requestCaptor.capture(), + ArgumentMatchers + .>> + any()); + + GetObjectRequest capturedRequest = requestCaptor.getValue(); + assertEquals( + "original-referrer,readMode=SYNC", + capturedRequest.overrideConfiguration().get().headers().get(HEADER_REFERER).get(0)); + } + @Test void testObjectClientClose() { try (S3AsyncClient s3AsyncClient = createMockClient()) {