Skip to content

Commit

Permalink
b/376134352 Expose SFTP file channel as stream (#1553)
Browse files Browse the repository at this point in the history
- Extend `SshFileSystemChannel` to return a `Stream` when creating or opening a file
- Add `Stream` implementation that reads or writes via SFTP
- Add `CopyToAsync` extension method
  • Loading branch information
jpassing authored Nov 15, 2024
1 parent df8361e commit 0072d43
Show file tree
Hide file tree
Showing 8 changed files with 549 additions and 308 deletions.
74 changes: 74 additions & 0 deletions sources/Google.Solutions.Common.Test/IO/TestStreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@


using Google.Solutions.Common.IO;
using Google.Solutions.Testing.Apis;
using Moq;
using NUnit.Framework;
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Google.Solutions.Common.Test.IO
{
Expand All @@ -35,6 +38,30 @@ public class TestStreamExtensions
// CopyTo.
//--------------------------------------------------------------------

[Test]
public void CopyTo_WhenSourceNotReadable()
{
var sourceStream = new Mock<Stream>();
sourceStream.SetupGet(s => s.CanRead).Returns(false);

Assert.Throws<NotSupportedException>(
() => sourceStream.Object.CopyTo(
new MemoryStream(),
new Mock<IProgress<int>>().Object));
}

[Test]
public void CopyTo_WhenDestinationNotWritable()
{
var destinationStream = new Mock<Stream>();
destinationStream.SetupGet(s => s.CanWrite).Returns(false);

Assert.Throws<NotSupportedException>(
() => new MemoryStream().CopyTo(
destinationStream.Object,
new Mock<IProgress<int>>().Object));
}

[Test]
public void CopyTo_ReportsProgress()
{
Expand All @@ -48,5 +75,52 @@ public void CopyTo_ReportsProgress()

progress.Verify(p => p.Report(It.IsAny<int>()), Times.AtLeast(4));
}

//--------------------------------------------------------------------
// CopyToAsync.
//--------------------------------------------------------------------

[Test]
public void CopyToAsync_WhenSourceNotReadable()
{
var sourceStream = new Mock<Stream>();
sourceStream.SetupGet(s => s.CanRead).Returns(false);

ExceptionAssert.ThrowsAggregateException<NotSupportedException>(
() => sourceStream.Object.CopyToAsync(
new MemoryStream(),
new Mock<IProgress<int>>().Object,
CancellationToken.None).Wait());
}

[Test]
public void CopyToAsync_WhenDestinationNotWritable()
{
var destinationStream = new Mock<Stream>();
destinationStream.SetupGet(s => s.CanWrite).Returns(false);

ExceptionAssert.ThrowsAggregateException<NotSupportedException>(
() => new MemoryStream().CopyToAsync(
destinationStream.Object,
new Mock<IProgress<int>>().Object,
CancellationToken.None).Wait());
}

[Test]
public async Task CopyToAsync_ReportsProgress()
{
var source = new byte[StreamExtensions.DefaultBufferSize * 3 + 1];
var sourceStream = new MemoryStream(source);

var progress = new Mock<IProgress<int>>();
await sourceStream
.CopyToAsync(
new MemoryStream(),
progress.Object,
CancellationToken.None)
.ConfigureAwait(false);

progress.Verify(p => p.Report(It.IsAny<int>()), Times.AtLeast(4));
}
}
}
75 changes: 75 additions & 0 deletions sources/Google.Solutions.Common/IO/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,31 @@

using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Google.Solutions.Common.IO
{
public static class StreamExtensions
{
internal const int DefaultBufferSize = 64 * 1024;

private static void ExpectReadable(Stream s)
{
if (!s.CanRead)
{
throw new NotSupportedException("Source is not readable");
}
}

private static void ExpectWritable(Stream s)
{
if (!s.CanWrite)
{
throw new NotSupportedException("Destination is not writable");
}
}

/// <summary>
/// Reads the bytes from the current stream and
/// writes them to another stream.
Expand All @@ -38,6 +56,9 @@ public static void CopyTo(
IProgress<int> progress,
int bufferSize = DefaultBufferSize)
{
ExpectReadable(source);
ExpectWritable(destination);

var buffer = new byte[bufferSize];
int count;
while ((count = source.Read(buffer, 0, buffer.Length)) != 0)
Expand All @@ -46,5 +67,59 @@ public static void CopyTo(
progress.Report(count);
}
}


/// <summary>
/// Reads the bytes from the current stream and
/// writes them to another stream.
/// </summary>
public static async Task CopyToAsync(
this Stream source,
Stream destination,
IProgress<int> progress,
int bufferSize,
CancellationToken cancellationToken)
{
ExpectReadable(source);
ExpectWritable(destination);

var buffer = new byte[bufferSize];
int count;
while ((count = await source
.ReadAsync(
buffer,
0,
buffer.Length,
cancellationToken)
.ConfigureAwait(false)) != 0)
{
await destination
.WriteAsync(
buffer,
0,
count,
cancellationToken)
.ConfigureAwait(false);
progress.Report(count);
}
}

/// <summary>
/// Reads the bytes from the current stream and
/// writes them to another stream.
/// </summary>
public static Task CopyToAsync(
this Stream source,
Stream destination,
IProgress<int> progress,
CancellationToken cancellationToken)
{
return CopyToAsync(
source,
destination,
progress,
DefaultBufferSize,
cancellationToken);
}
}
}
Loading

0 comments on commit 0072d43

Please sign in to comment.