-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Feat/memory transport #3022
Open
pyropy
wants to merge
14
commits into
libp2p:master
Choose a base branch
from
pyropy:feat/memory-transport
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+993
−11
Open
Changes from 12 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
39983c2
Add memory transport
pyropy 37899d3
Daily commit
pyropy 0df38f9
Daily commit
pyropy 67da192
Upgrade go-multiaddr
pyropy e2a5865
Daily commit
pyropy 079bd3e
Use plain channels to send data between streams
pyropy e3203f2
Daily commit
pyropy 4164b48
Merge branch 'master' into feat/memory-transport
pyropy b96f386
Daily commit
pyropy 5236ff2
Daily commit
pyropy f511c2d
Revert to stream using channels instead of io.Pipe
pyropy 35ebf85
Update errors when remote stream is closed
pyropy 2e3378f
Add read and write timeouts
pyropy b948613
Merge branch 'master' into feat/memory-transport
pyropy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
package memory | ||
|
||
import ( | ||
"context" | ||
"log" | ||
"sync" | ||
"sync/atomic" | ||
|
||
ic "github.com/libp2p/go-libp2p/core/crypto" | ||
"github.com/libp2p/go-libp2p/core/network" | ||
"github.com/libp2p/go-libp2p/core/peer" | ||
tpt "github.com/libp2p/go-libp2p/core/transport" | ||
ma "github.com/multiformats/go-multiaddr" | ||
) | ||
|
||
type conn struct { | ||
id int64 | ||
rconn *conn | ||
|
||
scope network.ConnManagementScope | ||
listener *listener | ||
transport *transport | ||
|
||
localPeer peer.ID | ||
localMultiaddr ma.Multiaddr | ||
|
||
remotePeerID peer.ID | ||
remotePubKey ic.PubKey | ||
remoteMultiaddr ma.Multiaddr | ||
|
||
mu sync.Mutex | ||
|
||
closed atomic.Bool | ||
closeOnce sync.Once | ||
|
||
streamC chan *stream | ||
streams map[int64]network.MuxedStream | ||
} | ||
|
||
var _ tpt.CapableConn = &conn{} | ||
|
||
func newConnection( | ||
t *transport, | ||
s *stream, | ||
localPeer peer.ID, | ||
localMultiaddr ma.Multiaddr, | ||
remotePubKey ic.PubKey, | ||
remotePeer peer.ID, | ||
remoteMultiaddr ma.Multiaddr, | ||
) *conn { | ||
c := &conn{ | ||
id: connCounter.Add(1), | ||
transport: t, | ||
localPeer: localPeer, | ||
localMultiaddr: localMultiaddr, | ||
remotePubKey: remotePubKey, | ||
remotePeerID: remotePeer, | ||
remoteMultiaddr: remoteMultiaddr, | ||
streamC: make(chan *stream, 1), | ||
streams: make(map[int64]network.MuxedStream), | ||
} | ||
|
||
c.addStream(s.id, s) | ||
return c | ||
} | ||
|
||
func (c *conn) Close() error { | ||
c.closeOnce.Do(func() { | ||
c.closed.Store(true) | ||
go c.rconn.Close() | ||
c.teardown() | ||
}) | ||
|
||
return nil | ||
} | ||
|
||
func (c *conn) IsClosed() bool { | ||
return c.closed.Load() | ||
} | ||
|
||
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { | ||
sl, sr := newStreamPair() | ||
sl.conn = c | ||
c.addStream(sl.id, sl) | ||
log.Println("opening stream", sl.id, sr.id) | ||
|
||
c.rconn.streamC <- sr | ||
return sl, nil | ||
} | ||
|
||
func (c *conn) AcceptStream() (network.MuxedStream, error) { | ||
in := <-c.streamC | ||
in.conn = c | ||
c.addStream(in.id, in) | ||
return in, nil | ||
} | ||
|
||
func (c *conn) LocalPeer() peer.ID { return c.localPeer } | ||
|
||
// RemotePeer returns the peer ID of the remote peer. | ||
func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } | ||
|
||
// RemotePublicKey returns the public pkey of the remote peer. | ||
func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } | ||
|
||
// LocalMultiaddr returns the local Multiaddr associated | ||
func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } | ||
|
||
// RemoteMultiaddr returns the remote Multiaddr associated | ||
func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } | ||
|
||
func (c *conn) Transport() tpt.Transport { | ||
return c.transport | ||
} | ||
|
||
func (c *conn) Scope() network.ConnScope { | ||
return c.scope | ||
} | ||
|
||
// ConnState is the state of security connection. | ||
func (c *conn) ConnState() network.ConnectionState { | ||
return network.ConnectionState{Transport: "memory"} | ||
} | ||
|
||
func (c *conn) addStream(id int64, stream network.MuxedStream) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
c.streams[id] = stream | ||
} | ||
|
||
func (c *conn) removeStream(id int64) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
delete(c.streams, id) | ||
} | ||
|
||
func (c *conn) teardown() { | ||
for id, s := range c.streams { | ||
log.Println("tearing down stream", id) | ||
s.Reset() | ||
} | ||
|
||
// TODO: remove self from listener | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
package memory | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package memory | ||
|
||
import ( | ||
ic "github.com/libp2p/go-libp2p/core/crypto" | ||
"github.com/libp2p/go-libp2p/core/peer" | ||
ma "github.com/multiformats/go-multiaddr" | ||
mafmt "github.com/multiformats/go-multiaddr-fmt" | ||
"sync" | ||
"sync/atomic" | ||
) | ||
|
||
var ( | ||
connCounter atomic.Int64 | ||
streamCounter atomic.Int64 | ||
listenerCounter atomic.Int64 | ||
dialMatcher = mafmt.Base(ma.P_MEMORY) | ||
memhub = newHub() | ||
) | ||
|
||
type hub struct { | ||
mu sync.RWMutex | ||
closeOnce sync.Once | ||
pubKeys map[peer.ID]ic.PubKey | ||
listeners map[string]*listener | ||
} | ||
|
||
func newHub() *hub { | ||
return &hub{ | ||
pubKeys: make(map[peer.ID]ic.PubKey), | ||
listeners: make(map[string]*listener), | ||
} | ||
} | ||
|
||
func (h *hub) addListener(addr string, l *listener) { | ||
h.mu.Lock() | ||
defer h.mu.Unlock() | ||
|
||
h.listeners[addr] = l | ||
} | ||
|
||
func (h *hub) removeListener(addr string, l *listener) { | ||
h.mu.Lock() | ||
defer h.mu.Unlock() | ||
|
||
delete(h.listeners, addr) | ||
} | ||
|
||
func (h *hub) getListener(addr string) (*listener, bool) { | ||
h.mu.RLock() | ||
defer h.mu.RUnlock() | ||
|
||
l, ok := h.listeners[addr] | ||
return l, ok | ||
} | ||
|
||
func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { | ||
h.mu.Lock() | ||
defer h.mu.Unlock() | ||
|
||
h.pubKeys[p] = pk | ||
} | ||
|
||
func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { | ||
h.mu.RLock() | ||
defer h.mu.RUnlock() | ||
|
||
pk, ok := h.pubKeys[p] | ||
return pk, ok | ||
} | ||
|
||
func (h *hub) close() { | ||
h.closeOnce.Do(func() { | ||
h.mu.Lock() | ||
defer h.mu.Unlock() | ||
|
||
for _, l := range h.listeners { | ||
l.Close() | ||
} | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package memory | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"sync" | ||
|
||
tpt "github.com/libp2p/go-libp2p/core/transport" | ||
ma "github.com/multiformats/go-multiaddr" | ||
) | ||
|
||
const ( | ||
listenerQueueSize = 16 | ||
) | ||
|
||
type listener struct { | ||
id int64 | ||
|
||
t *transport | ||
ctx context.Context | ||
cancel context.CancelFunc | ||
laddr ma.Multiaddr | ||
|
||
mu sync.Mutex | ||
connCh chan *conn | ||
connections map[int64]*conn | ||
} | ||
|
||
func (l *listener) Multiaddr() ma.Multiaddr { | ||
return l.laddr | ||
} | ||
|
||
func newListener(t *transport, laddr ma.Multiaddr) *listener { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
return &listener{ | ||
id: listenerCounter.Add(1), | ||
t: t, | ||
ctx: ctx, | ||
cancel: cancel, | ||
laddr: laddr, | ||
connCh: make(chan *conn, listenerQueueSize), | ||
connections: make(map[int64]*conn), | ||
} | ||
} | ||
|
||
// Accept accepts new connections. | ||
func (l *listener) Accept() (tpt.CapableConn, error) { | ||
select { | ||
case <-l.ctx.Done(): | ||
return nil, tpt.ErrListenerClosed | ||
case c, ok := <-l.connCh: | ||
if !ok { | ||
return nil, tpt.ErrListenerClosed | ||
} | ||
|
||
l.mu.Lock() | ||
defer l.mu.Unlock() | ||
|
||
c.listener = l | ||
c.transport = l.t | ||
l.connections[c.id] = c | ||
return c, nil | ||
} | ||
} | ||
|
||
// Close closes the listener. | ||
func (l *listener) Close() error { | ||
l.cancel() | ||
return nil | ||
} | ||
|
||
// Addr returns the address of this listener. | ||
func (l *listener) Addr() net.Addr { | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
package memory | ||
|
||
import ( | ||
"bytes" | ||
"errors" | ||
"io" | ||
"net" | ||
"sync" | ||
"time" | ||
|
||
"github.com/libp2p/go-libp2p/core/network" | ||
) | ||
|
||
// onceError is an object that will only store an error once. | ||
type onceError struct { | ||
sync.Mutex // guards following | ||
err error | ||
} | ||
|
||
func (a *onceError) Store(err error) { | ||
a.Lock() | ||
defer a.Unlock() | ||
if a.err != nil { | ||
return | ||
} | ||
a.err = err | ||
} | ||
func (a *onceError) Load() error { | ||
a.Lock() | ||
defer a.Unlock() | ||
return a.err | ||
} | ||
|
||
// stream implements network.Stream | ||
type stream struct { | ||
id int64 | ||
conn *conn | ||
|
||
wrMu sync.Mutex // Serialize Write operations | ||
buf *bytes.Buffer // Buffer for partial reads | ||
|
||
// Used by local Read to interact with remote Write. | ||
rdRx <-chan []byte | ||
|
||
// Used by local Write to interact with remote Read. | ||
wrTx chan<- []byte | ||
|
||
closeOnce sync.Once // Protects closing localDone | ||
localDone chan struct{} | ||
remoteDone <-chan struct{} | ||
|
||
resetOnce sync.Once // Protects closing localReset | ||
localReset chan struct{} | ||
remoteReset <-chan struct{} | ||
|
||
rerr onceError | ||
werr onceError | ||
} | ||
|
||
var ErrClosed = errors.New("stream closed") | ||
|
||
func newStreamPair() (*stream, *stream) { | ||
cb1 := make(chan []byte, 1) | ||
cb2 := make(chan []byte, 1) | ||
|
||
done1 := make(chan struct{}) | ||
done2 := make(chan struct{}) | ||
|
||
reset1 := make(chan struct{}) | ||
reset2 := make(chan struct{}) | ||
|
||
sa := newStream(cb1, cb2, done1, done2, reset1, reset2) | ||
sb := newStream(cb2, cb1, done2, done1, reset2, reset1) | ||
|
||
return sa, sb | ||
} | ||
|
||
func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}, localReset chan struct{}, remoteReset <-chan struct{}) *stream { | ||
s := &stream{ | ||
id: streamCounter.Add(1), | ||
rdRx: rdRx, | ||
wrTx: wrTx, | ||
buf: new(bytes.Buffer), | ||
localDone: localDone, | ||
remoteDone: remoteDone, | ||
localReset: localReset, | ||
remoteReset: remoteReset, | ||
} | ||
|
||
return s | ||
} | ||
|
||
func (p *stream) Write(b []byte) (int, error) { | ||
if err := p.werr.Load(); err != nil { | ||
return 0, err | ||
} | ||
|
||
return p.write(b) | ||
//if err != nil && err != io.ErrClosedPipe && err != network.ErrReset { | ||
// err = &net.OpError{Op: "write", Net: "pipe", Err: err} | ||
//} | ||
// | ||
//return n, err | ||
} | ||
|
||
func (p *stream) write(b []byte) (n int, err error) { | ||
switch { | ||
case isClosedChan(p.remoteReset): | ||
return 0, network.ErrReset | ||
case isClosedChan(p.remoteDone): | ||
return 0, io.ErrClosedPipe | ||
} | ||
|
||
p.wrMu.Lock() // Ensure entirety of b is written together | ||
defer p.wrMu.Unlock() | ||
|
||
select { | ||
case p.wrTx <- b: | ||
n += len(b) | ||
} | ||
|
||
return n, nil | ||
} | ||
|
||
func (p *stream) Read(b []byte) (int, error) { | ||
if err := p.rerr.Load(); err != nil { | ||
return 0, err | ||
} | ||
|
||
return p.read(b) | ||
//if err != nil && err != io.EOF && err != io.ErrClosedPipe && err != network.ErrReset { | ||
// err = &net.OpError{Op: "read", Net: "pipe", Err: err} | ||
//} | ||
// | ||
//return n, err | ||
} | ||
|
||
func (p *stream) read(b []byte) (n int, err error) { | ||
var readErr error | ||
|
||
switch { | ||
case isClosedChan(p.remoteReset): | ||
err = network.ErrReset | ||
case isClosedChan(p.remoteDone): | ||
err = io.EOF | ||
} | ||
|
||
select { | ||
case bw, ok := <-p.rdRx: | ||
if !ok { | ||
err = io.EOF | ||
p.rerr.Store(err) | ||
return | ||
} | ||
|
||
p.buf.Write(bw) | ||
default: | ||
} | ||
|
||
n, readErr = p.buf.Read(b) | ||
if err == nil { | ||
err = readErr | ||
} | ||
|
||
return n, err | ||
} | ||
|
||
func (s *stream) CloseWrite() error { | ||
s.werr.Store(ErrClosed) | ||
return nil | ||
} | ||
|
||
func (s *stream) CloseRead() error { | ||
s.rerr.Store(ErrClosed) | ||
return nil | ||
} | ||
|
||
func (s *stream) Close() error { | ||
s.closeOnce.Do(func() { | ||
close(s.localDone) | ||
}) | ||
|
||
_ = s.CloseRead() | ||
return s.CloseWrite() | ||
} | ||
|
||
func (s *stream) Reset() error { | ||
s.rerr.Store(network.ErrReset) | ||
s.werr.Store(network.ErrReset) | ||
|
||
s.resetOnce.Do(func() { | ||
close(s.localReset) | ||
}) | ||
|
||
// No meaningful error case here. | ||
return nil | ||
} | ||
|
||
func (s *stream) SetDeadline(t time.Time) error { | ||
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} | ||
} | ||
|
||
func (s *stream) SetReadDeadline(t time.Time) error { | ||
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} | ||
} | ||
|
||
func (s *stream) SetWriteDeadline(t time.Time) error { | ||
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} | ||
} | ||
|
||
func isClosedChan(c <-chan struct{}) bool { | ||
select { | ||
case <-c: | ||
return true | ||
default: | ||
return false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package memory | ||
|
||
import ( | ||
"errors" | ||
"github.com/libp2p/go-libp2p/core/network" | ||
"github.com/stretchr/testify/require" | ||
"io" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestStreamSimpleReadWriteClose(t *testing.T) { | ||
t.Parallel() | ||
streamLocal, streamRemote := newStreamPair() | ||
|
||
// send a foobar from the client | ||
n, err := streamLocal.Write([]byte("foobar")) | ||
require.NoError(t, err) | ||
require.Equal(t, 6, n) | ||
require.NoError(t, streamLocal.CloseWrite()) | ||
|
||
// writing after closing should error | ||
_, err = streamLocal.Write([]byte("foobar")) | ||
require.Error(t, err) | ||
|
||
// now read all the data on the server side | ||
b, err := io.ReadAll(streamRemote) | ||
require.NoError(t, err) | ||
require.Equal(t, []byte("foobar"), b) | ||
|
||
// reading again should give another io.EOF | ||
n, err = streamRemote.Read(make([]byte, 10)) | ||
require.Zero(t, n) | ||
require.ErrorIs(t, err, io.EOF) | ||
|
||
// send something back | ||
_, err = streamRemote.Write([]byte("lorem ipsum")) | ||
require.NoError(t, err) | ||
require.NoError(t, streamRemote.CloseWrite()) | ||
|
||
// and read it at the client | ||
b, err = io.ReadAll(streamLocal) | ||
require.NoError(t, err) | ||
require.Equal(t, []byte("lorem ipsum"), b) | ||
|
||
// stream is only cleaned up on calling Close or Reset | ||
require.NoError(t, streamLocal.Close()) | ||
require.NoError(t, streamRemote.Close()) | ||
} | ||
|
||
func TestStreamPartialReads(t *testing.T) { | ||
t.Parallel() | ||
streamLocal, streamRemote := newStreamPair() | ||
|
||
_, err := streamRemote.Write([]byte("foobar")) | ||
require.NoError(t, err) | ||
require.NoError(t, streamRemote.CloseWrite()) | ||
|
||
n, err := streamLocal.Read([]byte{}) // empty read | ||
require.NoError(t, err) | ||
require.Zero(t, n) | ||
b := make([]byte, 3) | ||
n, err = streamLocal.Read(b) | ||
require.Equal(t, 3, n) | ||
require.NoError(t, err) | ||
require.Equal(t, []byte("foo"), b) | ||
b, err = io.ReadAll(streamLocal) | ||
require.NoError(t, err) | ||
require.Equal(t, []byte("bar"), b) | ||
} | ||
|
||
func TestStreamResets(t *testing.T) { | ||
clientStr, serverStr := newStreamPair() | ||
|
||
// send a foobar from the client | ||
_, err := clientStr.Write([]byte("foobar")) | ||
require.NoError(t, err) | ||
_, err = serverStr.Write([]byte("lorem ipsum")) | ||
require.NoError(t, err) | ||
require.NoError(t, clientStr.Reset()) // resetting resets both directions | ||
// attempting to write more data should result in a reset error | ||
_, err = clientStr.Write([]byte("foobar")) | ||
require.ErrorIs(t, err, network.ErrReset) | ||
// read what the server sent | ||
b, err := io.ReadAll(clientStr) | ||
require.Empty(t, b) | ||
require.ErrorIs(t, err, network.ErrReset) | ||
|
||
// read the data on the server side | ||
b, err = io.ReadAll(serverStr) | ||
require.Equal(t, []byte("foobar"), b) | ||
require.ErrorIs(t, err, network.ErrReset) | ||
require.Eventually(t, func() bool { | ||
_, err := serverStr.Write([]byte("foobar")) | ||
return errors.Is(err, network.ErrReset) | ||
}, time.Second, 50*time.Millisecond) | ||
serverStr.Close() | ||
} | ||
|
||
func TestStreamReadAfterClose(t *testing.T) { | ||
clientStr, serverStr := newStreamPair() | ||
|
||
serverStr.Close() | ||
b := make([]byte, 1) | ||
_, err := clientStr.Read(b) | ||
require.Equal(t, io.EOF, err) | ||
_, err = clientStr.Read(nil) | ||
require.Equal(t, io.EOF, err) | ||
|
||
clientStr, serverStr = newStreamPair() | ||
|
||
serverStr.Reset() | ||
b = make([]byte, 1) | ||
_, err = clientStr.Read(b) | ||
require.ErrorIs(t, err, network.ErrReset) | ||
_, err = clientStr.Read(nil) | ||
require.ErrorIs(t, err, network.ErrReset) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
package memory | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
ic "github.com/libp2p/go-libp2p/core/crypto" | ||
"github.com/libp2p/go-libp2p/core/network" | ||
"github.com/libp2p/go-libp2p/core/peer" | ||
"github.com/libp2p/go-libp2p/core/pnet" | ||
tpt "github.com/libp2p/go-libp2p/core/transport" | ||
ma "github.com/multiformats/go-multiaddr" | ||
"sync" | ||
) | ||
|
||
type transport struct { | ||
psk pnet.PSK | ||
rcmgr network.ResourceManager | ||
localPeerID peer.ID | ||
localPrivKey ic.PrivKey | ||
localPubKey ic.PubKey | ||
|
||
mu sync.RWMutex | ||
|
||
connections map[int64]*conn | ||
} | ||
|
||
func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { | ||
if rcmgr == nil { | ||
rcmgr = &network.NullResourceManager{} | ||
} | ||
|
||
id, err := peer.IDFromPrivateKey(privKey) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
memhub.addPubKey(id, privKey.GetPublic()) | ||
return &transport{ | ||
psk: psk, | ||
rcmgr: rcmgr, | ||
localPeerID: id, | ||
localPrivKey: privKey, | ||
localPubKey: privKey.GetPublic(), | ||
connections: make(map[int64]*conn), | ||
}, nil | ||
} | ||
|
||
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { | ||
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
c, err := t.dialWithScope(ctx, raddr, p, scope) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return c, nil | ||
} | ||
|
||
func (t *transport) dialWithScope(_ context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { | ||
if err := scope.SetPeer(rpid); err != nil { | ||
return nil, err | ||
} | ||
|
||
rl, ok := memhub.getListener(raddr.String()) | ||
if !ok { | ||
return nil, errors.New("failed to get listener") | ||
} | ||
|
||
remotePubKey, ok := memhub.getPubKey(rpid) | ||
if !ok { | ||
return nil, errors.New("failed to get remote public key") | ||
} | ||
|
||
lc, rc := t.newConnPair(remotePubKey, rpid, raddr) | ||
|
||
rl.connCh <- rc | ||
return lc, nil | ||
} | ||
|
||
func (t *transport) CanDial(addr ma.Multiaddr) bool { | ||
return dialMatcher.Matches(addr) | ||
} | ||
|
||
func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { | ||
// TODO: Check if we need to add scope via conn mngr | ||
l := newListener(t, laddr) | ||
memhub.addListener(laddr.String(), l) | ||
|
||
return l, nil | ||
} | ||
|
||
func (t *transport) Proxy() bool { | ||
return false | ||
} | ||
|
||
// Protocols returns the set of protocols handled by this transport. | ||
func (t *transport) Protocols() []int { | ||
return []int{ma.P_MEMORY} | ||
} | ||
|
||
func (t *transport) String() string { | ||
return "MemoryTransport" | ||
} | ||
|
||
func (t *transport) Close() error { | ||
// TODO: Go trough all listeners and close them | ||
t.mu.Lock() | ||
defer t.mu.Unlock() | ||
|
||
for _, c := range t.connections { | ||
c.Close() | ||
//delete(t.connections, c.id) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (t *transport) addConn(c *conn) { | ||
t.mu.Lock() | ||
defer t.mu.Unlock() | ||
|
||
t.connections[c.id] = c | ||
} | ||
|
||
func (t *transport) removeConn(c *conn) { | ||
t.mu.Lock() | ||
defer t.mu.Unlock() | ||
|
||
delete(t.connections, c.id) | ||
} | ||
|
||
func (t *transport) newConnPair(remotePubKey ic.PubKey, rpid peer.ID, raddr ma.Multiaddr) (*conn, *conn) { | ||
sl, sr := newStreamPair() | ||
|
||
lc := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) | ||
rc := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) | ||
|
||
lc.rconn = rc | ||
rc.rconn = lc | ||
return lc, rc | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package memory | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"io" | ||
"testing" | ||
|
||
ic "github.com/libp2p/go-libp2p/core/crypto" | ||
tpt "github.com/libp2p/go-libp2p/core/transport" | ||
|
||
ma "github.com/multiformats/go-multiaddr" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func getTransport(t *testing.T) tpt.Transport { | ||
t.Helper() | ||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) | ||
require.NoError(t, err) | ||
key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) | ||
require.NoError(t, err) | ||
tr, err := NewTransport(key, nil, nil) | ||
require.NoError(t, err) | ||
return tr | ||
} | ||
|
||
func TestMemoryProtocol(t *testing.T) { | ||
t.Parallel() | ||
tr := getTransport(t) | ||
defer tr.(io.Closer).Close() | ||
|
||
protocols := tr.Protocols() | ||
if len(protocols) > 1 { | ||
t.Fatalf("expected at most one protocol, got %v", protocols) | ||
} | ||
|
||
if protocols[0] != ma.P_MEMORY { | ||
t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) | ||
} | ||
} | ||
|
||
func TestCanDial(t *testing.T) { | ||
t.Parallel() | ||
tr := getTransport(t) | ||
defer tr.(io.Closer).Close() | ||
|
||
invalid := []string{ | ||
"/ip4/127.0.0.1/udp/1234", | ||
"/ip4/5.5.5.5/tcp/1234", | ||
"/dns/google.com/udp/443/quic-v1", | ||
"/ip4/127.0.0.1/udp/1234/quic", | ||
} | ||
valid := []string{ | ||
"/memory/1234", | ||
"/memory/1337123", | ||
} | ||
for _, s := range invalid { | ||
invalidAddr, err := ma.NewMultiaddr(s) | ||
require.NoError(t, err) | ||
if tr.CanDial(invalidAddr) { | ||
t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) | ||
} | ||
} | ||
for _, s := range valid { | ||
validAddr, err := ma.NewMultiaddr(s) | ||
require.NoError(t, err) | ||
if !tr.CanDial(validAddr) { | ||
t.Errorf("expected to be able to dial memory address (%s)", validAddr) | ||
} | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Write end to end transport tests here. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Write these before transport integration tests.