Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pkg: avoid log.Fatal calls #5700

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions pkg/flatrpc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package flatrpc

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -13,9 +14,10 @@ import (
"sync"
"unsafe"

"github.com/google/flatbuffers/go"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
"golang.org/x/sync/errgroup"
)

var (
Expand All @@ -30,38 +32,55 @@ type Serv struct {
ln net.Listener
}

func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
go func() {
for {
conn, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
var netErr *net.OpError
if errors.As(err, &netErr) && !netErr.Temporary() {
log.Fatalf("flatrpc: failed to accept: %v", err)
}
log.Logf(0, "flatrpc: failed to accept: %v", err)
continue
}
go func() {
c := NewConn(conn)
defer c.Close()
handler(c)
}()
}
}()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}

// Serve accepts incoming connections and calls handler for each of them.
// An error returned from the handler stops the server and aborts the whole processing.
func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
eg, ctx := errgroup.WithContext(baseCtx)
go func() {
// If the context is cancelled, stop the server.
<-ctx.Done()
s.Close()
}()
for {
conn, err := s.ln.Accept()
if err != nil && errors.Is(err, net.ErrClosed) {
break
}
if err != nil {
var netErr *net.OpError
if errors.As(err, &netErr) && !netErr.Temporary() {
return fmt.Errorf("flatrpc: failed to accept: %w", err)
}
log.Logf(0, "flatrpc: failed to accept: %v", err)
continue
}
eg.Go(func() error {
connCtx, cancel := context.WithCancel(ctx)
defer cancel()

c := NewConn(conn)
// Closing the server does not automatically close all the connections.
go func() {
<-connCtx.Done()
dvyukov marked this conversation as resolved.
Show resolved Hide resolved
c.Close()
}()
return handler(connCtx, c)
})
}
return eg.Wait()
}

func (s *Serv) Close() error {
return s.ln.Close()
}
Expand Down
104 changes: 61 additions & 43 deletions pkg/flatrpc/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
package flatrpc

import (
"context"
"fmt"
"net"
"os"
"reflect"
"runtime/debug"
"sync"
"syscall"
"testing"
"time"

"github.com/google/flatbuffers/go"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -40,35 +43,39 @@ func TestConn(t *testing.T) {
},
}

done := make(chan bool)
defer func() {
<-done
}()
serv, err := ListenAndServe(":0", func(c *Conn) {
defer close(done)
connectReqGot, err := Recv[*ConnectRequestRaw](c)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, connectReq, connectReqGot)

if err := Send(c, connectReply); err != nil {
t.Fatal(err)
}

for i := 0; i < 10; i++ {
got, err := Recv[*ExecutorMessageRaw](c)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, executorMsg, got)
}
})
serv, err := Listen(":0")
if err != nil {
t.Fatal(err)
}
defer serv.Close()

done := make(chan error)
go func() {
done <- serv.Serve(context.Background(),
func(_ context.Context, c *Conn) error {
connectReqGot, err := Recv[*ConnectRequestRaw](c)
if err != nil {
return err
}
if !reflect.DeepEqual(connectReq, connectReqGot) {
return fmt.Errorf("connectReq != connectReqGot")
}

if err := Send(c, connectReply); err != nil {
return err
}

for i := 0; i < 10; i++ {
got, err := Recv[*ExecutorMessageRaw](c)
if err != nil {
return nil
}
if !reflect.DeepEqual(executorMsg, got) {
return fmt.Errorf("executorMsg !=got")
}
}
return nil
})
}()
c := dial(t, serv.Addr.String())
defer c.Close()

Expand All @@ -87,6 +94,11 @@ func TestConn(t *testing.T) {
t.Fatal(err)
}
}

serv.Close()
if err := <-done; err != nil {
t.Fatal(err)
}
}

func BenchmarkConn(b *testing.B) {
Expand All @@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) {
Files: []string{"file1"},
}

done := make(chan bool)
defer func() {
<-done
}()
serv, err := ListenAndServe(":0", func(c *Conn) {
defer close(done)
for i := 0; i < b.N; i++ {
_, err := Recv[*ConnectRequestRaw](c)
if err != nil {
b.Fatal(err)
}
if err := Send(c, connectReply); err != nil {
b.Fatal(err)
}
}
})
serv, err := Listen(":0")
if err != nil {
b.Fatal(err)
}
defer serv.Close()
done := make(chan error)

go func() {
done <- serv.Serve(context.Background(),
func(_ context.Context, c *Conn) error {
for i := 0; i < b.N; i++ {
_, err := Recv[*ConnectRequestRaw](c)
if err != nil {
return err
}
if err := Send(c, connectReply); err != nil {
return err
}
}
return nil
})
}()

c := dial(b, serv.Addr.String())
defer c.Close()
Expand All @@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) {
b.Fatal(err)
}
}

serv.Close()
if err := <-done; err != nil {
b.Fatal(err)
}
}

func dial(t testing.TB, addr string) *Conn {
Expand Down
47 changes: 34 additions & 13 deletions pkg/manager/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,15 @@ func RunDiffFuzzer(ctx context.Context, baseCfg, newCfg *mgrconfig.Config, debug
if err != nil {
return err
}
go func() {
new.candidates <- LoadSeeds(newCfg, true).Candidates
}()
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
info, err := LoadSeeds(newCfg, true)
if err != nil {
return err
}
new.candidates <- info.Candidates
a-nogikh marked this conversation as resolved.
Show resolved Hide resolved
return nil
})

stream := queue.NewRandomQueue(4096, rand.New(rand.NewSource(time.Now().UnixNano())))
base.source = stream
Expand All @@ -73,8 +79,10 @@ func RunDiffFuzzer(ctx context.Context, baseCfg, newCfg *mgrconfig.Config, debug
}
new.http = diffCtx.http
}
diffCtx.Loop(ctx)
return nil
eg.Go(func() error {
return diffCtx.Loop(ctx)
})
return eg.Wait()
a-nogikh marked this conversation as resolved.
Show resolved Hide resolved
}

type diffContext struct {
Expand Down Expand Up @@ -102,7 +110,11 @@ func (dc *diffContext) Loop(baseCtx context.Context) error {
// Let both base and patched instances somewhat progress in fuzzing before we take
// VMs away for bug reproduction.
// TODO: determine the exact moment of corpus triage.
time.Sleep(15 * time.Minute)
select {
case <-time.After(15 * time.Minute):
case <-ctx.Done():
return nil
}
log.Logf(0, "starting bug reproductions")
reproLoop.Loop(ctx)
return nil
Expand Down Expand Up @@ -297,9 +309,10 @@ func (kc *kernelContext) BugFrames() (leaks, races []string) {
return nil, nil
}

func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
func (kc *kernelContext) MachineChecked(features flatrpc.Feature,
syscalls map[*prog.Syscall]bool) (queue.Source, error) {
if len(syscalls) == 0 {
log.Fatalf("all system calls are disabled")
return nil, fmt.Errorf("all system calls are disabled")
}
log.Logf(0, "%s: machine check complete", kc.name)
kc.features = features
Expand All @@ -311,7 +324,7 @@ func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*
source = kc.source
}
opts := fuzzer.DefaultExecOpts(kc.cfg, features, kc.debug)
return queue.DefaultOpts(source, opts)
return queue.DefaultOpts(source, opts), nil
}

func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
Expand Down Expand Up @@ -343,7 +356,15 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
kc.http.Corpus.Store(corpusObj)
}

filtered := FilterCandidates(<-kc.candidates, syscalls, false).Candidates
var candidates []fuzzer.Candidate
select {
case candidates = <-kc.candidates:
case <-kc.ctx.Done():
// The loop will be aborted later.
break
}

filtered := FilterCandidates(candidates, syscalls, false).Candidates
log.Logf(0, "%s: adding %d seeds", kc.name, len(filtered))
fuzzerObj.AddCandidates(filtered)

Expand All @@ -367,11 +388,11 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
return fuzzerObj
}

func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
kc.reportGenerator.Init(modules)
filters, err := PrepareCoverageFilters(kc.reportGenerator, kc.cfg, false)
if err != nil {
log.Fatalf("failed to init coverage filter: %v", err)
return nil, fmt.Errorf("failed to init coverage filter: %w", err)
}
kc.coverFilters = filters
log.Logf(0, "cover filter size: %d", len(filters.ExecutorFilter))
Expand All @@ -386,7 +407,7 @@ func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64
for pc := range filters.ExecutorFilter {
pcs = append(pcs, pc)
}
return pcs
return pcs, nil
}

func (kc *kernelContext) fuzzerInstance(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
Expand Down
Loading
Loading