diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
index 33eca07e172f..47b2654937fc 100644
--- a/pkg/flatrpc/conn.go
+++ b/pkg/flatrpc/conn.go
@@ -4,6 +4,7 @@
 package flatrpc
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -13,9 +14,10 @@ import (
 	"sync"
 	"unsafe"
 
-	"github.com/google/flatbuffers/go"
+	flatbuffers "github.com/google/flatbuffers/go"
 	"github.com/google/syzkaller/pkg/log"
 	"github.com/google/syzkaller/pkg/stat"
+	"golang.org/x/sync/errgroup"
 )
 
 var (
@@ -30,38 +32,55 @@ type Serv struct {
 	ln net.Listener
 }
 
-func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
+func Listen(addr string) (*Serv, error) {
 	ln, err := net.Listen("tcp", addr)
 	if err != nil {
 		return nil, err
 	}
-	go func() {
-		for {
-			conn, err := ln.Accept()
-			if err != nil {
-				if errors.Is(err, net.ErrClosed) {
-					break
-				}
-				var netErr *net.OpError
-				if errors.As(err, &netErr) && !netErr.Temporary() {
-					log.Fatalf("flatrpc: failed to accept: %v", err)
-				}
-				log.Logf(0, "flatrpc: failed to accept: %v", err)
-				continue
-			}
-			go func() {
-				c := NewConn(conn)
-				defer c.Close()
-				handler(c)
-			}()
-		}
-	}()
 	return &Serv{
 		Addr: ln.Addr().(*net.TCPAddr),
 		ln:   ln,
 	}, nil
 }
 
+// Serve accepts incoming connections and calls handler for each of them.
+// An error returned from the handler stops the server and aborts the whole processing.
+func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
+	eg, ctx := errgroup.WithContext(baseCtx)
+	go func() {
+		// If the context is cancelled, stop the server.
+		<-ctx.Done()
+		s.Close()
+	}()
+	for {
+		conn, err := s.ln.Accept()
+		if err != nil && errors.Is(err, net.ErrClosed) {
+			break
+		}
+		if err != nil {
+			var netErr *net.OpError
+			if errors.As(err, &netErr) && !netErr.Temporary() {
+				return fmt.Errorf("flatrpc: failed to accept: %w", err)
+			}
+			log.Logf(0, "flatrpc: failed to accept: %v", err)
+			continue
+		}
+		eg.Go(func() error {
+			connCtx, cancel := context.WithCancel(ctx)
+			defer cancel()
+
+			c := NewConn(conn)
+			// Closing the server does not automatically close all the connections.
+			go func() {
+				<-connCtx.Done()
+				c.Close()
+			}()
+			return handler(connCtx, c)
+		})
+	}
+	return eg.Wait()
+}
+
 func (s *Serv) Close() error {
 	return s.ln.Close()
 }
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
index 132fd1cddaee..4b108a5a4360 100644
--- a/pkg/flatrpc/conn_test.go
+++ b/pkg/flatrpc/conn_test.go
@@ -4,15 +4,18 @@
 package flatrpc
 
 import (
+	"context"
+	"fmt"
 	"net"
 	"os"
+	"reflect"
 	"runtime/debug"
 	"sync"
 	"syscall"
 	"testing"
 	"time"
 
-	"github.com/google/flatbuffers/go"
+	flatbuffers "github.com/google/flatbuffers/go"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -40,35 +43,39 @@ func TestConn(t *testing.T) {
 		},
 	}
 
-	done := make(chan bool)
-	defer func() {
-		<-done
-	}()
-	serv, err := ListenAndServe(":0", func(c *Conn) {
-		defer close(done)
-		connectReqGot, err := Recv[*ConnectRequestRaw](c)
-		if err != nil {
-			t.Fatal(err)
-		}
-		assert.Equal(t, connectReq, connectReqGot)
-
-		if err := Send(c, connectReply); err != nil {
-			t.Fatal(err)
-		}
-
-		for i := 0; i < 10; i++ {
-			got, err := Recv[*ExecutorMessageRaw](c)
-			if err != nil {
-				t.Fatal(err)
-			}
-			assert.Equal(t, executorMsg, got)
-		}
-	})
+	serv, err := Listen(":0")
 	if err != nil {
 		t.Fatal(err)
 	}
-	defer serv.Close()
+	done := make(chan error)
+	go func() {
+		done <- serv.Serve(context.Background(),
+			func(_ context.Context, c *Conn) error {
+				connectReqGot, err := Recv[*ConnectRequestRaw](c)
+				if err != nil {
+					return err
+				}
+				if !reflect.DeepEqual(connectReq, connectReqGot) {
+					return fmt.Errorf("connectReq != connectReqGot")
+				}
+
+				if err := Send(c, connectReply); err != nil {
+					return err
+				}
+
+				for i := 0; i < 10; i++ {
+					got, err := Recv[*ExecutorMessageRaw](c)
+					if err != nil {
+						return err
+					}
+					if !reflect.DeepEqual(executorMsg, got) {
+						return fmt.Errorf("executorMsg != got")
+					}
+				}
+				return nil
+			})
+	}()
 
 	c := dial(t, serv.Addr.String())
 	defer c.Close()
@@ -87,6 +94,11 @@ func TestConn(t *testing.T) {
 			t.Fatal(err)
 		}
 	}
+
+	serv.Close()
+	if err := <-done; err != nil {
+		t.Fatal(err)
+	}
 }
 
 func BenchmarkConn(b *testing.B) {
@@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) {
 		Files: []string{"file1"},
 	}
 
-	done := make(chan bool)
-	defer func() {
-		<-done
-	}()
-	serv, err := ListenAndServe(":0", func(c *Conn) {
-		defer close(done)
-		for i := 0; i < b.N; i++ {
-			_, err := Recv[*ConnectRequestRaw](c)
-			if err != nil {
-				b.Fatal(err)
-			}
-			if err := Send(c, connectReply); err != nil {
-				b.Fatal(err)
-			}
-		}
-	})
+	serv, err := Listen(":0")
 	if err != nil {
 		b.Fatal(err)
 	}
-	defer serv.Close()
+	done := make(chan error)
+
+	go func() {
+		done <- serv.Serve(context.Background(),
+			func(_ context.Context, c *Conn) error {
+				for i := 0; i < b.N; i++ {
+					_, err := Recv[*ConnectRequestRaw](c)
+					if err != nil {
+						return err
+					}
+					if err := Send(c, connectReply); err != nil {
+						return err
+					}
+				}
+				return nil
+			})
+	}()
 
 	c := dial(b, serv.Addr.String())
 	defer c.Close()
@@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) {
 			b.Fatal(err)
 		}
 	}
+
+	serv.Close()
+	if err := <-done; err != nil {
+		b.Fatal(err)
+	}
 }
 
 func dial(t testing.TB, addr string) *Conn {
diff --git a/pkg/manager/diff.go b/pkg/manager/diff.go
index a96d4f9a82c7..bdf638a5154a 100644
--- a/pkg/manager/diff.go
+++ b/pkg/manager/diff.go
@@ -45,9 +45,15 @@ func RunDiffFuzzer(ctx context.Context, baseCfg, newCfg *mgrconfig.Config, debug
 	if err != nil {
 		return err
 	}
-	go func() {
-		new.candidates <- LoadSeeds(newCfg, true).Candidates
-	}()
+	eg, ctx := errgroup.WithContext(ctx)
+	eg.Go(func() error {
+		info, err := LoadSeeds(newCfg, true)
+		if err != nil {
+			return err
+		}
+		new.candidates <- info.Candidates
+		return nil
+	})
 
 	stream := queue.NewRandomQueue(4096, rand.New(rand.NewSource(time.Now().UnixNano())))
 	base.source = stream
@@ -73,8 +79,10 @@ func RunDiffFuzzer(ctx context.Context, baseCfg, newCfg *mgrconfig.Config, debug
 		}
 		new.http = diffCtx.http
 	}
-	diffCtx.Loop(ctx)
-	return nil
+	eg.Go(func() error {
+		return diffCtx.Loop(ctx)
+	})
+	return eg.Wait()
 }
 
 type diffContext struct {
@@ -102,7 +110,11 @@ func (dc *diffContext) Loop(baseCtx context.Context) error {
 	// Let both base and patched instances somewhat progress in fuzzing before we take
 	// VMs away for bug reproduction.
 	// TODO: determine the exact moment of corpus triage.
-	time.Sleep(15 * time.Minute)
+	select {
+	case <-time.After(15 * time.Minute):
+	case <-ctx.Done():
+		return nil
+	}
 	log.Logf(0, "starting bug reproductions")
 	reproLoop.Loop(ctx)
 	return nil
@@ -297,9 +309,10 @@ func (kc *kernelContext) BugFrames() (leaks, races []string) {
 	return nil, nil
 }
 
-func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (kc *kernelContext) MachineChecked(features flatrpc.Feature,
+	syscalls map[*prog.Syscall]bool) (queue.Source, error) {
 	if len(syscalls) == 0 {
-		log.Fatalf("all system calls are disabled")
+		return nil, fmt.Errorf("all system calls are disabled")
 	}
 	log.Logf(0, "%s: machine check complete", kc.name)
 	kc.features = features
@@ -311,7 +324,7 @@ func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*
 		source = kc.source
 	}
 	opts := fuzzer.DefaultExecOpts(kc.cfg, features, kc.debug)
-	return queue.DefaultOpts(source, opts)
+	return queue.DefaultOpts(source, opts), nil
 }
 
 func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
@@ -343,7 +356,15 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
 		kc.http.Corpus.Store(corpusObj)
 	}
 
-	filtered := FilterCandidates(<-kc.candidates, syscalls, false).Candidates
+	var candidates []fuzzer.Candidate
+	select {
+	case candidates = <-kc.candidates:
+	case <-kc.ctx.Done():
+		// The loop will be aborted later.
+		break
+	}
+
+	filtered := FilterCandidates(candidates, syscalls, false).Candidates
 	log.Logf(0, "%s: adding %d seeds", kc.name, len(filtered))
 	fuzzerObj.AddCandidates(filtered)
 
@@ -367,11 +388,11 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
 	return fuzzerObj
 }
 
-func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
+func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
 	kc.reportGenerator.Init(modules)
 	filters, err := PrepareCoverageFilters(kc.reportGenerator, kc.cfg, false)
 	if err != nil {
-		log.Fatalf("failed to init coverage filter: %v", err)
+		return nil, fmt.Errorf("failed to init coverage filter: %w", err)
 	}
 	kc.coverFilters = filters
 	log.Logf(0, "cover filter size: %d", len(filters.ExecutorFilter))
@@ -386,7 +407,7 @@ func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64
 	for pc := range filters.ExecutorFilter {
 		pcs = append(pcs, pc)
 	}
-	return pcs
+	return pcs, nil
 }
 
 func (kc *kernelContext) fuzzerInstance(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
diff --git a/pkg/manager/seeds.go b/pkg/manager/seeds.go
index 2d769d43be03..f17c36ecb267 100644
--- a/pkg/manager/seeds.go
+++ b/pkg/manager/seeds.go
@@ -31,72 +31,25 @@ type Seeds struct {
 	Candidates []fuzzer.Candidate
 }
 
-func LoadSeeds(cfg *mgrconfig.Config, immutable bool) Seeds {
+func LoadSeeds(cfg *mgrconfig.Config, immutable bool) (Seeds, error) {
 	var info Seeds
 	var err error
 	info.CorpusDB, err = db.Open(filepath.Join(cfg.Workdir, "corpus.db"), !immutable)
 	if err != nil {
 		if info.CorpusDB == nil {
-			log.Fatalf("failed to open corpus database: %v", err)
+			return Seeds{}, fmt.Errorf("failed to open corpus database: %w", err)
 		}
 		log.Errorf("read %v inputs from corpus and got error: %v", len(info.CorpusDB.Records), err)
 	}
 	info.Fresh = len(info.CorpusDB.Records) == 0
 	corpusFlags := versionToFlags(info.CorpusDB.Version)
-	type Input struct {
-		IsSeed bool
-		Key    string
-		Path   string
-		Data   []byte
-		Prog   *prog.Prog
-		Err    error
-	}
-	procs := runtime.GOMAXPROCS(0)
-	inputs := make(chan *Input, procs)
-	outputs := make(chan *Input, procs)
-	var wg sync.WaitGroup
-	wg.Add(procs)
-	for p := 0; p < procs; p++ {
-		go func() {
-			defer wg.Done()
-			for inp := range inputs {
-				inp.Prog, inp.Err = ParseSeed(cfg.Target, inp.Data)
-				outputs <- inp
-			}
-		}()
-	}
+	outputs := make(chan *input, 32)
+	chErr := make(chan error, 1)
 	go func() {
-		wg.Wait()
+		chErr <- readInputs(cfg, info.CorpusDB, outputs)
 		close(outputs)
 	}()
-	go func() {
-		for key, rec := range info.CorpusDB.Records {
-			inputs <- &Input{
-				Key:  key,
-				Data: rec.Val,
-			}
-		}
-		seedPath := filepath.Join("sys", cfg.TargetOS, "test")
-		seedDir := filepath.Join(cfg.Syzkaller, seedPath)
-		if osutil.IsExist(seedDir) {
-			seeds, err := os.ReadDir(seedDir)
-			if err != nil {
-				log.Fatalf("failed to read seeds dir: %v", err)
-			}
-			for _, seed := range seeds {
-				data, err := os.ReadFile(filepath.Join(seedDir, seed.Name()))
-				if err != nil {
-					log.Fatalf("failed to read seed %v: %v", seed.Name(), err)
-				}
-				inputs <- &Input{
-					IsSeed: true,
-					Path:   filepath.Join(seedPath, seed.Name()),
-					Data:   data,
-				}
-			}
-		}
-		close(inputs)
-	}()
+
 	brokenSeeds := 0
 	skippedSeeds := 0
 	var brokenCorpus []string
@@ -130,6 +83,9 @@ func LoadSeeds(cfg *mgrconfig.Config, immutable bool) Seeds {
 			Flags:  flags,
 		})
 	}
+	if err := <-chErr; err != nil {
+		return Seeds{}, err
+	}
 	if len(brokenCorpus)+brokenSeeds != 0 {
 		log.Logf(0, "broken programs in the corpus: %v, broken seeds: %v", len(brokenCorpus), brokenSeeds)
 	}
@@ -142,14 +98,69 @@ func LoadSeeds(cfg *mgrconfig.Config, immutable bool) Seeds {
 			info.CorpusDB.Delete(sig)
 		}
 		if err := info.CorpusDB.Flush(); err != nil {
-			log.Fatalf("failed to save corpus database: %v", err)
+			return Seeds{}, fmt.Errorf("failed to save corpus database: %w", err)
 		}
 	}
 	// Switch database to the mode when it does not keep records in memory.
 	// We don't need them anymore and they consume lots of memory.
 	info.CorpusDB.DiscardData()
 	info.Candidates = candidates
-	return info
+	return info, nil
+}
+
+type input struct {
+	IsSeed bool
+	Key    string
+	Path   string
+	Data   []byte
+	Prog   *prog.Prog
+	Err    error
+}
+
+func readInputs(cfg *mgrconfig.Config, db *db.DB, output chan *input) error {
+	procs := runtime.GOMAXPROCS(0)
+	inputs := make(chan *input, procs)
+	var wg sync.WaitGroup
+	wg.Add(procs)
+
+	defer wg.Wait()
+	defer close(inputs)
+	for p := 0; p < procs; p++ {
+		go func() {
+			defer wg.Done()
+			for inp := range inputs {
+				inp.Prog, inp.Err = ParseSeed(cfg.Target, inp.Data)
+				output <- inp
+			}
+		}()
+	}
+
+	for key, rec := range db.Records {
+		inputs <- &input{
+			Key:  key,
+			Data: rec.Val,
+		}
+	}
+	seedPath := filepath.Join("sys", cfg.TargetOS, "test")
+	seedDir := filepath.Join(cfg.Syzkaller, seedPath)
+	if osutil.IsExist(seedDir) {
+		seeds, err := os.ReadDir(seedDir)
+		if err != nil {
+			return fmt.Errorf("failed to read seeds dir: %w", err)
+		}
+		for _, seed := range seeds {
+			data, err := os.ReadFile(filepath.Join(seedDir, seed.Name()))
+			if err != nil {
+				return fmt.Errorf("failed to read seed %v: %w", seed.Name(), err)
+			}
+			inputs <- &input{
+				IsSeed: true,
+				Path:   filepath.Join(seedPath, seed.Name()),
+				Data:   data,
+			}
+		}
+	}
+	return nil
 }
 
 const CurrentDBVersion = 5
diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go
index 4ab8827ae1e4..e4e128dcfcc6 100644
--- a/pkg/rpcserver/local.go
+++ b/pkg/rpcserver/local.go
@@ -16,6 +16,7 @@ import (
 	"github.com/google/syzkaller/pkg/signal"
 	"github.com/google/syzkaller/pkg/vminfo"
 	"github.com/google/syzkaller/prog"
+	"golang.org/x/sync/errgroup"
 )
 
 type LocalConfig struct {
@@ -39,26 +40,32 @@ func RunLocal(cfg *LocalConfig) error {
 	if cfg.VMArch == "" {
 		cfg.VMArch = cfg.Target.Arch
 	}
+	if cfg.Context == nil {
+		cfg.Context = context.Background()
+	}
 	cfg.UseCoverEdges = true
 	cfg.FilterSignal = true
 	cfg.RPC = ":0"
 	cfg.PrintMachineCheck = log.V(1)
 	cfg.Stats = NewStats()
-	ctx := &local{
+	localCtx := &local{
 		cfg:       cfg,
 		setupDone: make(chan bool),
 	}
-	serv := newImpl(&cfg.Config, ctx)
+	serv := newImpl(&cfg.Config, localCtx)
 	if err := serv.Listen(); err != nil {
 		return err
 	}
 	defer serv.Close()
-	ctx.serv = serv
+	localCtx.serv = serv
 	// setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked
 	// for the race detector b/c it does not understand the synchronization via TCP socket connect/accept.
-	close(ctx.setupDone)
+	close(localCtx.setupDone)
+
+	cancelCtx, cancel := context.WithCancel(cfg.Context)
+	eg, ctx := errgroup.WithContext(cancelCtx)
 
-	id := 0
+	const id = 0
 	connErr := serv.CreateInstance(id, nil, nil)
 	defer serv.ShutdownInstance(id, true)
 
@@ -73,7 +80,7 @@ func RunLocal(cfg *LocalConfig) error {
 			cfg.Executor,
 		}, args...)
 	}
-	cmd := exec.Command(bin, args...)
+	cmd := exec.CommandContext(ctx, bin, args...)
 	cmd.Dir = cfg.Dir
 	if cfg.Debug || cfg.GDB {
 		cmd.Stdout = os.Stdout
@@ -82,28 +89,32 @@ func RunLocal(cfg *LocalConfig) error {
 	if cfg.GDB {
 		cmd.Stdin = os.Stdin
 	}
-	if err := cmd.Start(); err != nil {
-		return fmt.Errorf("failed to start executor: %w", err)
-	}
-	res := make(chan error, 1)
-	go func() { res <- cmd.Wait() }()
+	eg.Go(func() error {
+		return serv.Serve(ctx)
+	})
+	eg.Go(func() error {
+		if err := cmd.Start(); err != nil {
+			return fmt.Errorf("failed to start executor: %w", err)
+		}
+		err := cmd.Wait()
+		// Note that we ignore the error if we killed the process by closing the context.
+		if err == nil || ctx.Err() != nil {
+			return nil
+		}
+		return fmt.Errorf("executor process exited: %w", err)
+	})
+
 	shutdown := make(chan struct{})
 	if cfg.HandleInterrupts {
 		osutil.HandleInterrupts(shutdown)
 	}
-	var cmdErr error
 	select {
+	case <-ctx.Done():
 	case <-shutdown:
-	case <-cfg.Context.Done():
 	case <-connErr:
-	case err := <-res:
-		cmdErr = fmt.Errorf("executor process exited: %w", err)
-	}
-	if cmdErr == nil {
-		cmd.Process.Kill()
-		<-res
 	}
-	return cmdErr
+	cancel()
+	return eg.Wait()
 }
 
 type local struct {
@@ -112,10 +123,10 @@ type local struct {
 	setupDone chan bool
 }
 
-func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
 	<-ctx.setupDone
 	ctx.serv.TriagedCorpus()
-	return ctx.cfg.MachineChecked(features, syscalls)
+	return ctx.cfg.MachineChecked(features, syscalls), nil
 }
 
 func (ctx *local) BugFrames() ([]string, []string) {
@@ -126,6 +137,6 @@ func (ctx *local) MaxSignal() signal.Signal {
 	return signal.FromRaw(ctx.cfg.MaxSignal, 0)
 }
 
-func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
-	return ctx.cfg.CoverFilter
+func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
+	return ctx.cfg.CoverFilter, nil
 }
diff --git a/pkg/rpcserver/mocks/Manager.go b/pkg/rpcserver/mocks/Manager.go
index 810b5028fa13..c0c9621de1a4 100644
--- a/pkg/rpcserver/mocks/Manager.go
+++ b/pkg/rpcserver/mocks/Manager.go
@@ -53,7 +53,7 @@ func (_m *Manager) BugFrames() ([]string, []string) {
 }
 
 // CoverageFilter provides a mock function with given fields: modules
-func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
+func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
 	ret := _m.Called(modules)
 
 	if len(ret) == 0 {
@@ -61,6 +61,10 @@ func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
 	}
 
 	var r0 []uint64
+	var r1 error
+	if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) ([]uint64, error)); ok {
+		return rf(modules)
+	}
 	if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) []uint64); ok {
 		r0 = rf(modules)
 	} else {
@@ -69,11 +73,17 @@ func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
 		}
 	}
 
-	return r0
+	if rf, ok := ret.Get(1).(func([]*vminfo.KernelModule) error); ok {
+		r1 = rf(modules)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
 }
 
 // MachineChecked provides a mock function with given fields: features, syscalls
-func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
 	ret := _m.Called(features, syscalls)
 
 	if len(ret) == 0 {
@@ -81,6 +91,10 @@ func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.S
 	}
 
 	var r0 queue.Source
+	var r1 error
+	if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) (queue.Source, error)); ok {
+		return rf(features, syscalls)
+	}
 	if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) queue.Source); ok {
 		r0 = rf(features, syscalls)
 	} else {
@@ -89,7 +103,13 @@ func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.S
 		}
 	}
 
-	return r0
+	if rf, ok := ret.Get(1).(func(flatrpc.Feature, map[*prog.Syscall]bool) error); ok {
+		r1 = rf(features, syscalls)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
 }
 
 // MaxSignal provides a mock function with given fields:
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index 003c5f4b978d..b3b518b04206 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -28,6 +28,7 @@ import (
 	"github.com/google/syzkaller/prog"
 	"github.com/google/syzkaller/sys/targets"
 	"github.com/google/syzkaller/vm/dispatcher"
+	"golang.org/x/sync/errgroup"
 )
 
 type Config struct {
@@ -63,8 +64,8 @@ type RemoteConfig struct {
 type Manager interface {
 	MaxSignal() signal.Signal
 	BugFrames() (leaks []string, races []string)
-	MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source
-	CoverageFilter(modules []*vminfo.KernelModule) []uint64
+	MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error)
+	CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error)
 }
 
 type Server interface {
@@ -72,6 +73,7 @@
 	Close() error
 	Port() int
 	TriagedCorpus()
+	Serve(context.Context) error
 	CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error
 	ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte)
 	StopFuzzing(id int)
@@ -88,6 +90,7 @@ type server struct {
 	checker *vminfo.Checker
 
 	infoOnce       sync.Once
+	checkOnce      sync.Once
 	checkDone      atomic.Bool
 	checkFailures  int
 	baseSource     *queue.DynamicSourceCtl
@@ -217,7 +220,7 @@ func (serv *server) Close() error {
 }
 
 func (serv *server) Listen() error {
-	s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn)
+	s, err := flatrpc.Listen(serv.cfg.RPC)
 	if err != nil {
 		return err
 	}
@@ -225,15 +228,25 @@ func (serv *server) Listen() error {
 	return nil
 }
 
+func (serv *server) Serve(ctx context.Context) error {
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
+			return serv.handleConn(ctx, g, conn)
+		})
+	})
+	return g.Wait()
+}
+
 func (serv *server) Port() int {
 	return serv.serv.Addr.Port
 }
 
-func (serv *server) handleConn(conn *flatrpc.Conn) {
+func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error {
 	connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
 	if err != nil {
 		log.Logf(1, "%s", err)
-		return
+		return nil
 	}
 	id := int(connectReq.Id)
 	log.Logf(1, "runner %v connected", id)
@@ -246,7 +259,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
 			serv.ShutdownInstance(id, true)
 		}()
 	} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
-		log.Fatal(err)
+		// This is a fatal error.
+		return err
 	}
 	serv.StatVMRestarts.Add(1)
 
@@ -255,15 +269,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
 	serv.mu.Unlock()
 	if runner == nil {
 		log.Logf(2, "unknown VM %v tries to connect", id)
-		return
+		return nil
 	}
 
-	err = serv.handleRunnerConn(runner, conn)
+	err = serv.handleRunnerConn(ctx, eg, runner, conn)
 	log.Logf(2, "runner %v: %v", id, err)
+
+	if err != nil && errors.Is(err, errFatal) {
+		log.Logf(0, "%v", err)
+		return err
+	}
+
 	runner.resultCh <- err
+	return nil
 }
 
-func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
+func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
+	runner *Runner, conn *flatrpc.Conn) error {
 	opts := &handshakeConfig{
 		VMLess: serv.cfg.VMLess,
 		Files:  serv.checker.RequiredFiles(),
@@ -278,22 +300,36 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
 		opts.Features = serv.cfg.Features
 	}
 
-	err := runner.Handshake(conn, opts)
+	info, err := runner.Handshake(conn, opts)
 	if err != nil {
 		log.Logf(1, "%v", err)
 		return err
 	}
 
+	serv.checkOnce.Do(func() {
+		// Run the machine check.
+		eg.Go(func() error {
+			if err := serv.runCheck(ctx, &info); err != nil {
+				return fmt.Errorf("%w: %w", errFatal, err)
+			}
+			return nil
+		})
+	})
+
 	if serv.triagedCorpus.Load() {
-		if err := runner.SendCorpusTriaged(); err != nil {
-			log.Logf(2, "%v", err)
-			return err
-		}
+		eg.Go(runner.SendCorpusTriaged)
 	}
 
+	go func() {
+		<-ctx.Done()
+		runner.Stop()
+	}()
+
 	return serv.connectionLoop(runner)
 }
 
+// Used for errors incompatible with further RPCServer operation.
+var errFatal = errors.New("aborting RPC server")
+
 func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
 	modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
 	if err != nil {
@@ -307,31 +343,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
 		log.Logf(0, "machine check failed: %v", infoReq.Error)
 		serv.checkFailures++
 		if serv.checkFailures == 10 {
-			log.Fatalf("machine check failing")
+			return handshakeResult{}, fmt.Errorf("%w: machine check failed too many times", errFatal)
 		}
 		return handshakeResult{}, errors.New("machine check failed")
 	}
 
+	var retErr error
 	serv.infoOnce.Do(func() {
 		serv.StatModules.Add(len(modules))
 		serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover)
-		serv.coverFilter = serv.mgr.CoverageFilter(modules)
-		// Flatbuffers don't do deep copy of byte slices,
-		// so clone manually since we pass it a goroutine.
-		for _, file := range infoReq.Files {
-			file.Data = slices.Clone(file.Data)
+		var err error
+		serv.coverFilter, err = serv.mgr.CoverageFilter(modules)
+		if err != nil {
+			retErr = fmt.Errorf("%w: %w", errFatal, err)
+			return
 		}
-		// Now execute check programs.
-		go func() {
-			if err := serv.runCheck(infoReq); err != nil {
-				log.Fatalf("check failed: %v", err)
-			}
-		}()
 	})
+	if retErr != nil {
+		return handshakeResult{}, retErr
+	}
+	// Flatbuffers don't do deep copy of byte slices,
+	// so clone manually since we may later pass it to a goroutine.
+	for _, file := range infoReq.Files {
+		file.Data = slices.Clone(file.Data)
+	}
 	canonicalizer := serv.canonicalModules.NewInstance(modules)
 	return handshakeResult{
 		CovFilter:     canonicalizer.Decanonicalize(serv.coverFilter),
 		MachineInfo:   machineInfo,
 		Canonicalizer: canonicalizer,
+		Files:         infoReq.Files,
+		Features:      infoReq.Features,
 	}, nil
 }
 
@@ -371,10 +412,8 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
 	return nil
 }
 
-func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
-	// TODO: take context as a parameter.
-	enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(context.Background(),
-		info.Files, info.Features)
+func (serv *server) runCheck(ctx context.Context, info *handshakeResult) error {
+	enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(ctx, info.Files, info.Features)
 	enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls)
 	// Note: need to print disbled syscalls before failing due to an error.
 	// This helps to debug "all system calls are disabled".
@@ -386,7 +425,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
 	}
 	enabledFeatures := features.Enabled()
 	serv.setupFeatures = features.NeedSetup()
-	newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
+	newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
+	if err != nil {
+		return err
+	}
 	serv.baseSource.Store(newSource)
 	serv.checkDone.Store(true)
 	return nil
diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go
index a885ad720f34..0c4984e93b93 100644
--- a/pkg/rpcserver/rpcserver_test.go
+++ b/pkg/rpcserver/rpcserver_test.go
@@ -4,10 +4,12 @@
 package rpcserver
 
 import (
+	"context"
 	"net"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/google/syzkaller/pkg/flatrpc"
 	"github.com/google/syzkaller/pkg/mgrconfig"
@@ -212,7 +214,11 @@ func TestHandleConn(t *testing.T) {
 		serv.CreateInstance(1, injectExec, nil)
 
 		go flatrpc.Send(clientConn, tt.req)
-		serv.handleConn(serverConn)
+		var eg errgroup.Group
+		serv.handleConn(context.Background(), &eg, serverConn)
+		if err := eg.Wait(); err != nil {
+			t.Fatal(err)
+		}
 	})
 	}
 }
diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go
index a6b763b9a79c..de38d29f7a5f 100644
--- a/pkg/rpcserver/runner.go
+++ b/pkg/rpcserver/runner.go
@@ -77,12 +77,14 @@ type handshakeConfig struct {
 }
 
 type handshakeResult struct {
+	Files         []*flatrpc.FileInfo
+	Features      []*flatrpc.FeatureInfo
 	CovFilter     []uint64
 	MachineInfo   []byte
 	Canonicalizer *cover.CanonicalizerInstance
 }
 
-func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error {
+func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) (handshakeResult, error) {
 	if runner.updInfo != nil {
 		runner.updInfo(func(info *dispatcher.Info) {
 			info.Status = "handshake"
@@ -104,21 +106,21 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error
 		Features: cfg.Features,
 	}
 	if err := flatrpc.Send(conn, connectReply); err != nil {
-		return err
+		return handshakeResult{}, err
 	}
 	infoReq, err := flatrpc.Recv[*flatrpc.InfoRequestRaw](conn)
 	if err != nil {
-		return err
+		return handshakeResult{}, err
 	}
 	ret, err := cfg.Callback(infoReq)
 	if err != nil {
-		return err
+		return handshakeResult{}, err
 	}
 	infoReply := &flatrpc.InfoReply{
 		CoverFilter: ret.CovFilter,
 	}
 	if err := flatrpc.Send(conn, infoReply); err != nil {
-		return err
+		return handshakeResult{}, err
 	}
 	runner.mu.Lock()
 	runner.conn = conn
@@ -132,7 +134,7 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error
 			info.DetailedStatus = runner.QueryStatus
 		})
 	}
-	return nil
+	return ret, nil
 }
 
 func (runner *Runner) ConnectionLoop() error {
diff --git a/syz-manager/manager.go b/syz-manager/manager.go
index 1ea489dd9198..fdb4929d79d6 100644
--- a/syz-manager/manager.go
+++ b/syz-manager/manager.go
@@ -310,6 +310,13 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) {
 	if err := mgr.serv.Listen(); err != nil {
 		log.Fatalf("failed to start rpc server: %v", err)
 	}
+	ctx := vm.ShutdownCtx()
+	go func() {
+		err := mgr.serv.Serve(ctx)
+		if err != nil {
+			log.Fatalf("%s", err)
+		}
+	}()
 	log.Logf(0, "serving rpc on tcp://%v", mgr.serv.Port())
 
 	if cfg.DashboardAddr != "" {
@@ -355,7 +362,6 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) {
 	mgr.http.ReproLoop = mgr.reproLoop
 	mgr.http.TogglePause = mgr.pool.TogglePause
 
-	ctx := vm.ShutdownCtx()
 	if mgr.cfg.HTTP != "" {
 		go func() {
 			err := mgr.http.Serve(ctx)
@@ -546,7 +552,10 @@ func (mgr *Manager) processRepro(res *manager.ReproResult) {
 }
 
 func (mgr *Manager) preloadCorpus() {
-	info := manager.LoadSeeds(mgr.cfg, false)
+	info, err := manager.LoadSeeds(mgr.cfg, false)
+	if err != nil {
+		log.Fatalf("failed to load corpus: %v", err)
+	}
 	mgr.fresh = info.Fresh
 	mgr.corpusDB = info.CorpusDB
 	mgr.corpusPreload <- info.Candidates
@@ -1085,9 +1094,10 @@ func (mgr *Manager) BugFrames() (leaks, races []string) {
 	return
 }
 
-func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map[*prog.Syscall]bool) queue.Source {
+func (mgr *Manager) MachineChecked(features flatrpc.Feature,
+	enabledSyscalls map[*prog.Syscall]bool) (queue.Source, error) {
 	if len(enabledSyscalls) == 0 {
-		log.Fatalf("all system calls are disabled")
+		return nil, fmt.Errorf("all system calls are disabled")
 	}
 	if mgr.mode.ExitAfterMachineCheck {
 		mgr.exit(mgr.mode.Name)
@@ -1162,15 +1172,15 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map
 			mgr.serv = nil
 			return queue.Callback(func() *queue.Request {
 				return nil
-			})
+			}), nil
 		}
-		return source
+		return source, nil
 	} else if mgr.mode == ModeCorpusRun {
 		ctx := &corpusRunner{
 			candidates: candidates,
 			rnd:        rand.New(rand.NewSource(time.Now().UnixNano())),
 		}
-		return queue.DefaultOpts(ctx, opts)
+		return queue.DefaultOpts(ctx, opts), nil
 	} else if mgr.mode == ModeRunTests {
 		ctx := &runtest.Context{
 			Dir: filepath.Join(mgr.cfg.Syzkaller, "sys", mgr.cfg.Target.OS, "test"),
@@ -1192,7 +1202,7 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map
 			}
 			mgr.exit("tests")
 		}()
-		return ctx
+		return ctx, nil
 	} else if mgr.mode == ModeIfaceProbe {
 		exec := queue.Plain()
 		go func() {
@@ -1206,7 +1216,7 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map
 			}
 			mgr.exit("interface probe")
 		}()
-		return exec
+		return exec, nil
 	}
 	panic(fmt.Sprintf("unexpected mode %q", mgr.mode.Name))
 }
@@ -1427,11 +1437,11 @@ func (mgr *Manager) dashboardReproTasks() {
 	}
 }
 
-func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
+func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
 	mgr.reportGenerator.Init(modules)
 	filters, err := manager.PrepareCoverageFilters(mgr.reportGenerator, mgr.cfg, true)
 	if err != nil {
-		log.Fatalf("failed to init coverage filter: %v", err)
+		return nil, fmt.Errorf("failed to init coverage filter: %w", err)
 	}
 	mgr.coverFilters = filters
 	mgr.http.Cover.Store(&manager.CoverageInfo{
@@ -1443,7 +1453,7 @@ func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
 	for pc := range filters.ExecutorFilter {
 		pcs = append(pcs, pc)
 	}
-	return pcs
+	return pcs, nil
 }
 
 func publicWebAddr(addr string) string {
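
The sketches below are illustrations, not part of the patch. First, a minimal example of how the split flatrpc.Listen/Serve API is driven (the main function, the no-op handler, and the shutdown order are assumptions modeled on the updated tests): Serve blocks until the listener stops or a handler returns an error, so callers run it on a goroutine and join it through a channel.

package main

import (
	"context"
	"log"

	"github.com/google/syzkaller/pkg/flatrpc"
)

func main() {
	serv, err := flatrpc.Listen(":0")
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("listening on %v", serv.Addr)

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() {
		done <- serv.Serve(ctx, func(ctx context.Context, c *flatrpc.Conn) error {
			// A real handler would call flatrpc.Recv/flatrpc.Send here.
			// Returning a non-nil error aborts the whole server.
			return nil
		})
	}()

	// ... dial serv.Addr and exchange messages here ...

	cancel() // Serve's watcher goroutine closes the listener on ctx.Done().
	if err := <-done; err != nil {
		log.Fatal(err)
	}
}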
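Second, the sentinel-error pattern the patch introduces in pkg/rpcserver (errFatal) in a standalone form: per-connection failures are logged and swallowed so one bad VM does not kill the server, while errors wrapped with the sentinel propagate through the errgroup and stop serving. The handshake function below is hypothetical; only the sentinel and the errors.Is check mirror the patch.

package main

import (
	"errors"
	"fmt"
)

// Used for errors incompatible with further server operation.
var errFatal = errors.New("aborting RPC server")

// handshake is a hypothetical stand-in for the patched connection handler.
func handshake(revisionOK bool) error {
	if !revisionOK {
		// Go 1.20+ can wrap two errors; errors.Is matches either one.
		return fmt.Errorf("%w: %w", errFatal, errors.New("revision mismatch"))
	}
	return errors.New("connection reset") // an ordinary per-connection error
}

func main() {
	for _, ok := range []bool{false, true} {
		err := handshake(ok)
		if errors.Is(err, errFatal) {
			fmt.Println("fatal, stop the server:", err) // propagate to the errgroup
		} else {
			fmt.Println("log and keep serving:", err)
		}
	}
}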
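Finally, the readInputs refactoring in pkg/manager/seeds.go relies on a subtle defer ordering: close(inputs) runs before wg.Wait(), so the workers drain the queue and exit before the function returns, and only then does the caller close the output channel and report the error. A reduced, self-contained sketch of the same pipeline (parseAll and the string output are stand-ins for readInputs and ParseSeed):

package main

import (
	"fmt"
	"runtime"
	"sync"
)

func parseAll(data [][]byte, output chan<- string) error {
	procs := runtime.GOMAXPROCS(0)
	inputs := make(chan []byte, procs)
	var wg sync.WaitGroup
	wg.Add(procs)
	// Defers run LIFO: close(inputs) fires first, then wg.Wait(),
	// so workers see the closed channel and finish before we return.
	defer wg.Wait()
	defer close(inputs)
	for p := 0; p < procs; p++ {
		go func() {
			defer wg.Done()
			for d := range inputs {
				output <- fmt.Sprintf("parsed %d bytes", len(d))
			}
		}()
	}
	for _, d := range data {
		inputs <- d
	}
	return nil
}

func main() {
	output := make(chan string, 32)
	errc := make(chan error, 1)
	go func() {
		// Mirrors LoadSeeds: the producer reports its error on a buffered
		// channel and closes the output only after all workers are done.
		errc <- parseAll([][]byte{{1}, {2, 3}}, output)
		close(output)
	}()
	for s := range output {
		fmt.Println(s)
	}
	if err := <-errc; err != nil {
		fmt.Println("error:", err)
	}
}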