From 2301da67a49d26a45fc2d35d46a7fd926e9a1f45 Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Fri, 24 Jan 2025 17:17:53 +0100 Subject: [PATCH] pkg/rpcserver: refactor to remove Fatalf calls --- pkg/flatrpc/conn.go | 59 ++++++++++++++++--------- pkg/rpcserver/local.go | 35 +++++++++------ pkg/rpcserver/rpcserver.go | 90 ++++++++++++++++++++++++++++---------- pkg/rpcserver/runner.go | 10 ++--- syz-manager/manager.go | 6 +++ 5 files changed, 137 insertions(+), 63 deletions(-) diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go index 33eca07e172f..8aae61d93d63 100644 --- a/pkg/flatrpc/conn.go +++ b/pkg/flatrpc/conn.go @@ -4,6 +4,7 @@ package flatrpc import ( + "context" "errors" "fmt" "io" @@ -13,9 +14,10 @@ import ( "sync" "unsafe" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/stat" + "golang.org/x/sync/errgroup" ) var ( @@ -30,36 +32,51 @@ type Serv struct { ln net.Listener } -func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) { +func Listen(addr string) (*Serv, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, err } - go func() { - for { - conn, err := ln.Accept() + return &Serv{ + Addr: ln.Addr().(*net.TCPAddr), + ln: ln, + }, nil +} + +// Serve accepts incoming connections and calls handler for each of them. +// An error returned from the handler stops the server and aborts the whole processing. +func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error { + eg, ctx := errgroup.WithContext(baseCtx) + eg.Go(func() error { + select { + case <-ctx.Done(): + break + } + s.Close() + return nil + }) + for { + conn, err := ln.Accept() + if err != nil && errors.Is(err, net.ErrClosed) { + break + } + eg.Go(func() error { if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } var netErr *net.OpError if errors.As(err, &netErr) && !netErr.Temporary() { - log.Fatalf("flatrpc: failed to accept: %v", err) + return fmt.Errorf("flatrpc: failed to accept: %v", err) } log.Logf(0, "flatrpc: failed to accept: %v", err) - continue + return nil } - go func() { - c := NewConn(conn) - defer c.Close() - handler(c) - }() - } - }() - return &Serv{ - Addr: ln.Addr().(*net.TCPAddr), - ln: ln, - }, nil + + c := NewConn(conn) + defer c.Close() + + return handler(ctx, c) + }) + } + return eg.Wait() } func (s *Serv) Close() error { diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go index 5faa8334be87..83d4f54d65f0 100644 --- a/pkg/rpcserver/local.go +++ b/pkg/rpcserver/local.go @@ -16,6 +16,7 @@ import ( "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/pkg/vminfo" "github.com/google/syzkaller/prog" + "golang.org/x/sync/errgroup" ) type LocalConfig struct { @@ -39,12 +40,15 @@ func RunLocal(cfg *LocalConfig) error { if cfg.VMArch == "" { cfg.VMArch = cfg.Target.Arch } + if cfg.Context == nil { + cfg.Context = context.Background() + } cfg.UseCoverEdges = true cfg.FilterSignal = true cfg.RPC = ":0" cfg.PrintMachineCheck = log.V(1) cfg.Stats = NewStats() - ctx := &local{ + localCtx := &local{ cfg: cfg, setupDone: make(chan bool), } @@ -52,11 +56,13 @@ func RunLocal(cfg *LocalConfig) error { if err := serv.Listen(); err != nil { return err } - defer serv.Close() - ctx.serv = serv + defer serv.Close() // TODO: + localCtx.serv = serv // setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked // for the race detector b/c it does not understand the synchronization via TCP socket connect/accept. - close(ctx.setupDone) + close(localCtx.setupDone) + + eg, ctx := errgroup.WithContext(cfg.Context) id := 0 connErr := serv.CreateInstance(id, nil, nil) @@ -82,11 +88,14 @@ func RunLocal(cfg *LocalConfig) error { if cfg.GDB { cmd.Stdin = os.Stdin } - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start executor: %w", err) - } - res := make(chan error, 1) - go func() { res <- cmd.Wait() }() + eg.Go(func() error { + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start executor: %w", err) + } + return cmd.Wait() + }) + + // TODO: cancel context on shutdown. shutdown := make(chan struct{}) if cfg.HandleInterrupts { osutil.HandleInterrupts(shutdown) @@ -112,10 +121,10 @@ type local struct { setupDone chan bool } -func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { +func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) { <-ctx.setupDone ctx.serv.TriagedCorpus() - return ctx.cfg.MachineChecked(features, syscalls) + return ctx.cfg.MachineChecked(features, syscalls), nil } func (ctx *local) BugFrames() ([]string, []string) { @@ -126,6 +135,6 @@ func (ctx *local) MaxSignal() signal.Signal { return signal.FromRaw(ctx.cfg.MaxSignal, 0) } -func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { - return ctx.cfg.CoverFilter +func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) { + return ctx.cfg.CoverFilter, nil } diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index 9d259b733186..09395c6b5678 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -28,6 +28,7 @@ import ( "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" "github.com/google/syzkaller/vm/dispatcher" + "golang.org/x/sync/errgroup" ) type Config struct { @@ -63,8 +64,8 @@ type RemoteConfig struct { type Manager interface { MaxSignal() signal.Signal BugFrames() (leaks []string, races []string) - MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source - CoverageFilter(modules []*vminfo.KernelModule) []uint64 + MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) + CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) } type Server interface { @@ -88,6 +89,7 @@ type server struct { checker *vminfo.Checker infoOnce sync.Once + checkOnce sync.Once checkDone atomic.Bool checkFailures int baseSource *queue.DynamicSourceCtl @@ -217,7 +219,7 @@ func (serv *server) Close() error { } func (serv *server) Listen() error { - s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn) + s, err := flatrpc.Listen(serv.cfg.RPC) if err != nil { return err } @@ -225,15 +227,25 @@ func (serv *server) Listen() error { return nil } +func (serv *server) Serve(baseCtx context.Context) error { + eg, ctx := errgroup.WithContext(baseCtx) + eg.Go(func() error { + return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error { + return serv.handleConn(ctx, eg, conn) + }) + }) + return eg.Wait() +} + func (serv *server) Port() int { return serv.serv.Addr.Port } -func (serv *server) handleConn(conn *flatrpc.Conn) { +func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error { connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { log.Logf(1, "%s", err) - return + return nil } id := int(connectReq.Id) log.Logf(1, "runner %v connected", id) @@ -246,7 +258,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) { serv.ShutdownInstance(id, true) }() } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { - log.Fatal(err) + // This is a fatal error. + return err } serv.StatVMRestarts.Add(1) @@ -255,15 +268,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) { serv.mu.Unlock() if runner == nil { log.Logf(2, "unknown VM %v tries to connect", id) - return + return nil } - err = serv.handleRunnerConn(runner, conn) + err = serv.handleRunnerConn(ctx, eg, runner, conn) log.Logf(2, "runner %v: %v", id, err) + + if err != nil && errors.Is(err, errFatal) { + log.Logf(0, "%v", err) + return err + } + runner.resultCh <- err + return nil } -func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { +func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group, + runner *Runner, conn *flatrpc.Conn) error { opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, Files: serv.checker.RequiredFiles(), @@ -278,12 +299,22 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { opts.Features = serv.cfg.Features } - err := runner.Handshake(conn, opts) + info, err := runner.Handshake(conn, opts) if err != nil { log.Logf(1, "%v", err) return err } + serv.checkOnce.Do(func() { + // Run the machine check. + eg.Go(func() error { + if err := serv.runCheck(infoReq); err != nil { + return fmt.Errorf("check failed: %w", err) + } + return nil + }) + }) + if serv.triagedCorpus.Load() { if err := runner.SendCorpusTriaged(); err != nil { log.Logf(2, "%v", err) @@ -294,6 +325,9 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { return serv.connectionLoop(runner) } +// Used for errors incompatible with further RPCServer operation. +var errFatal = errors.New("aborting RPC server") + func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { @@ -307,31 +341,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha log.Logf(0, "machine check failed: %v", infoReq.Error) serv.checkFailures++ if serv.checkFailures == 10 { - log.Fatalf("machine check failing") + return handshakeResult{}, errors.New("%w: machine check failed too many times", errFatal) } return handshakeResult{}, errors.New("machine check failed") } + var retErr error serv.infoOnce.Do(func() { serv.StatModules.Add(len(modules)) serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover) - serv.coverFilter = serv.mgr.CoverageFilter(modules) - // Flatbuffers don't do deep copy of byte slices, - // so clone manually since we pass it a goroutine. - for _, file := range infoReq.Files { - file.Data = slices.Clone(file.Data) + var err error + serv.coverFilter, err = serv.mgr.CoverageFilter(modules) + if err != nil { + retErr = fmt.Errorf("%w: %w", errFatal, err) + return } - // Now execute check programs. - go func() { - if err := serv.runCheck(infoReq); err != nil { - log.Fatalf("check failed: %v", err) - } - }() }) + if retErr != nil { + return handshakeResult{}, retErr + } + // Flatbuffers don't do deep copy of byte slices, + // so clone manually since we may later pass it a goroutine. + for _, file := range infoReq.Files { + file.Data = slices.Clone(file.Data) + } canonicalizer := serv.canonicalModules.NewInstance(modules) return handshakeResult{ CovFilter: canonicalizer.Decanonicalize(serv.coverFilter), MachineInfo: machineInfo, Canonicalizer: canonicalizer, + Files: infoReq.Files, + Features: infoReq.Features, }, nil } @@ -371,7 +410,7 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { return nil } -func (serv *server) runCheck(info *flatrpc.InfoRequest) error { +func (serv *server) runCheck(info *handshakeResult) error { enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(info.Files, info.Features) enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls) // Note: need to print disbled syscalls before failing due to an error. @@ -384,7 +423,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error { } enabledFeatures := features.Enabled() serv.setupFeatures = features.NeedSetup() - newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls) + newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls) + if err != nil { + return err + } serv.baseSource.Store(newSource) serv.checkDone.Store(true) return nil diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go index a6b763b9a79c..a4ac8efec985 100644 --- a/pkg/rpcserver/runner.go +++ b/pkg/rpcserver/runner.go @@ -82,7 +82,7 @@ type handshakeResult struct { Canonicalizer *cover.CanonicalizerInstance } -func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error { +func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) (handshakeResult, error) { if runner.updInfo != nil { runner.updInfo(func(info *dispatcher.Info) { info.Status = "handshake" @@ -104,21 +104,21 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error Features: cfg.Features, } if err := flatrpc.Send(conn, connectReply); err != nil { - return err + return handshakeResult{}, err } infoReq, err := flatrpc.Recv[*flatrpc.InfoRequestRaw](conn) if err != nil { - return err + return handshakeResult{}, err } ret, err := cfg.Callback(infoReq) if err != nil { - return err + return handshakeResult{}, err } infoReply := &flatrpc.InfoReply{ CoverFilter: ret.CovFilter, } if err := flatrpc.Send(conn, infoReply); err != nil { - return err + return handshakeResult{}, err } runner.mu.Lock() runner.conn = conn diff --git a/syz-manager/manager.go b/syz-manager/manager.go index 7a85a6c9af01..e72c59ec485e 100644 --- a/syz-manager/manager.go +++ b/syz-manager/manager.go @@ -310,6 +310,12 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) { if err := mgr.serv.Listen(); err != nil { log.Fatalf("failed to start rpc server: %v", err) } + go func() { + err := mgr.serv.Serve(ctx) + if err != nil { + log.Fatalf("%s", err) + } + }() log.Logf(0, "serving rpc on tcp://%v", mgr.serv.Port()) if cfg.DashboardAddr != "" {