|
6 | 6 | "testing"
|
7 | 7 | "time"
|
8 | 8 |
|
| 9 | + "github.com/datarhei/gosrt/packet" |
9 | 10 | "github.com/stretchr/testify/assert"
|
10 | 11 | "github.com/stretchr/testify/require"
|
11 | 12 | )
|
@@ -146,6 +147,159 @@ func TestEncryption(t *testing.T) {
|
146 | 147 | require.Equal(t, message, reader1)
|
147 | 148 | }
|
148 | 149 |
|
| 150 | +// Test for https://github.com/datarhei/gosrt/pull/94 |
| 151 | +func TestEncryptionRetransmit(t *testing.T) { |
| 152 | + message := "Hello World!" |
| 153 | + passphrase := "foobarfoobar" |
| 154 | + channel := NewPubSub(PubSubConfig{}) |
| 155 | + |
| 156 | + config := DefaultConfig() |
| 157 | + config.EnforcedEncryption = true |
| 158 | + |
| 159 | + server := Server{ |
| 160 | + Addr: "127.0.0.1:6003", |
| 161 | + Config: &config, |
| 162 | + HandleConnect: func(req ConnRequest) ConnType { |
| 163 | + if req.IsEncrypted() { |
| 164 | + if err := req.SetPassphrase(passphrase); err != nil { |
| 165 | + return REJECT |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + streamid := req.StreamId() |
| 170 | + |
| 171 | + if streamid == "publish" { |
| 172 | + return PUBLISH |
| 173 | + } else if streamid == "subscribe" { |
| 174 | + return SUBSCRIBE |
| 175 | + } |
| 176 | + |
| 177 | + return REJECT |
| 178 | + }, |
| 179 | + HandlePublish: func(conn Conn) { |
| 180 | + channel.Publish(conn) |
| 181 | + |
| 182 | + conn.Close() |
| 183 | + }, |
| 184 | + HandleSubscribe: func(conn Conn) { |
| 185 | + channel.Subscribe(conn) |
| 186 | + |
| 187 | + conn.Close() |
| 188 | + }, |
| 189 | + } |
| 190 | + |
| 191 | + err := server.Listen() |
| 192 | + require.NoError(t, err) |
| 193 | + |
| 194 | + defer server.Shutdown() |
| 195 | + |
| 196 | + go func() { |
| 197 | + err := server.Serve() |
| 198 | + if err == ErrServerClosed { |
| 199 | + return |
| 200 | + } |
| 201 | + require.NoError(t, err) |
| 202 | + }() |
| 203 | + |
| 204 | + { |
| 205 | + // Reject connection if wrong password is set |
| 206 | + config := DefaultConfig() |
| 207 | + config.StreamId = "subscribe" |
| 208 | + config.Passphrase = "barfoobarfoo" |
| 209 | + |
| 210 | + _, err := Dial("srt", "127.0.0.1:6003", config) |
| 211 | + require.Error(t, err) |
| 212 | + } |
| 213 | + |
| 214 | + // Test transmitting an encrypted message |
| 215 | + |
| 216 | + readerConnected := make(chan struct{}) |
| 217 | + readerDone := make(chan struct{}) |
| 218 | + |
| 219 | + dataReader1 := bytes.Buffer{} |
| 220 | + |
| 221 | + go func() { |
| 222 | + defer close(readerDone) |
| 223 | + |
| 224 | + config := DefaultConfig() |
| 225 | + config.StreamId = "subscribe" |
| 226 | + config.Passphrase = "foobarfoobar" |
| 227 | + |
| 228 | + conn, err := Dial("srt", "127.0.0.1:6003", config) |
| 229 | + if !assert.NoError(t, err) { |
| 230 | + panic(err.Error()) |
| 231 | + } |
| 232 | + |
| 233 | + close(readerConnected) |
| 234 | + |
| 235 | + buffer := make([]byte, 2048) |
| 236 | + |
| 237 | + for { |
| 238 | + n, err := conn.Read(buffer) |
| 239 | + if n != 0 { |
| 240 | + dataReader1.Write(buffer[:n]) |
| 241 | + } |
| 242 | + |
| 243 | + if err != nil { |
| 244 | + break |
| 245 | + } |
| 246 | + } |
| 247 | + |
| 248 | + err = conn.Close() |
| 249 | + require.NoError(t, err) |
| 250 | + }() |
| 251 | + |
| 252 | + <-readerConnected |
| 253 | + |
| 254 | + writerDone := make(chan struct{}) |
| 255 | + |
| 256 | + go func() { |
| 257 | + defer close(writerDone) |
| 258 | + |
| 259 | + config := DefaultConfig() |
| 260 | + config.StreamId = "publish" |
| 261 | + config.Passphrase = "foobarfoobar" |
| 262 | + |
| 263 | + conn, err := Dial("srt", "127.0.0.1:6003", config) |
| 264 | + if !assert.NoError(t, err) { |
| 265 | + panic(err.Error()) |
| 266 | + } |
| 267 | + |
| 268 | + dialer, _ := conn.(*dialer) |
| 269 | + originalOnSend := dialer.conn.onSend |
| 270 | + dialer.conn.onSend = func(p packet.Packet) { |
| 271 | + if !p.Header().IsControlPacket { |
| 272 | + // Drop every 2nd original packet |
| 273 | + if !p.Header().RetransmittedPacketFlag && p.Header().PacketSequenceNumber.Val()%2 == 1 { |
| 274 | + return |
| 275 | + } |
| 276 | + } |
| 277 | + |
| 278 | + originalOnSend(p) |
| 279 | + } |
| 280 | + |
| 281 | + for i := 0; i < 5; i++ { |
| 282 | + n, err := conn.Write([]byte(message)) |
| 283 | + if !assert.NoError(t, err) { |
| 284 | + panic(err.Error()) |
| 285 | + } |
| 286 | + assert.Equal(t, 12, n) |
| 287 | + } |
| 288 | + |
| 289 | + time.Sleep(3 * time.Second) |
| 290 | + |
| 291 | + err = conn.Close() |
| 292 | + assert.NoError(t, err) |
| 293 | + }() |
| 294 | + |
| 295 | + <-writerDone |
| 296 | + <-readerDone |
| 297 | + |
| 298 | + reader1 := dataReader1.String() |
| 299 | + |
| 300 | + require.Equal(t, message+message+message+message+message, reader1) |
| 301 | +} |
| 302 | + |
149 | 303 | func TestEncryptionKeySwap(t *testing.T) {
|
150 | 304 | message := "Hello World!"
|
151 | 305 | passphrase := "foobarfoobar"
|
|
0 commit comments