Skip to content

Commit

Permalink
pkg/rpcserver: refactor to remove Fatalf calls
Browse files Browse the repository at this point in the history
  • Loading branch information
a-nogikh committed Jan 24, 2025
1 parent c67b65d commit 2301da6
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 63 deletions.
59 changes: 38 additions & 21 deletions pkg/flatrpc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package flatrpc

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -13,9 +14,10 @@ import (
"sync"
"unsafe"

"github.com/google/flatbuffers/go"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
"golang.org/x/sync/errgroup"
)

var (
Expand All @@ -30,36 +32,51 @@ type Serv struct {
ln net.Listener
}

func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
go func() {
for {
conn, err := ln.Accept()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}

// Serve accepts incoming connections and calls handler for each of them.
// An error returned from the handler stops the server and aborts the whole processing.
func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
eg, ctx := errgroup.WithContext(baseCtx)
eg.Go(func() error {
select {
case <-ctx.Done():
break
}
s.Close()
return nil
})
for {
conn, err := ln.Accept()
if err != nil && errors.Is(err, net.ErrClosed) {
break
}
eg.Go(func() error {
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
var netErr *net.OpError
if errors.As(err, &netErr) && !netErr.Temporary() {
log.Fatalf("flatrpc: failed to accept: %v", err)
return fmt.Errorf("flatrpc: failed to accept: %v", err)
}
log.Logf(0, "flatrpc: failed to accept: %v", err)
continue
return nil
}
go func() {
c := NewConn(conn)
defer c.Close()
handler(c)
}()
}
}()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil

c := NewConn(conn)
defer c.Close()

return handler(ctx, c)
})
}
return eg.Wait()
}

func (s *Serv) Close() error {
Expand Down
35 changes: 22 additions & 13 deletions pkg/rpcserver/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/google/syzkaller/pkg/signal"
"github.com/google/syzkaller/pkg/vminfo"
"github.com/google/syzkaller/prog"
"golang.org/x/sync/errgroup"
)

type LocalConfig struct {
Expand All @@ -39,24 +40,29 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.VMArch == "" {
cfg.VMArch = cfg.Target.Arch
}
if cfg.Context == nil {
cfg.Context = context.Background()
}
cfg.UseCoverEdges = true
cfg.FilterSignal = true
cfg.RPC = ":0"
cfg.PrintMachineCheck = log.V(1)
cfg.Stats = NewStats()
ctx := &local{
localCtx := &local{
cfg: cfg,
setupDone: make(chan bool),
}
serv := newImpl(cfg.Context, &cfg.Config, ctx)
if err := serv.Listen(); err != nil {
return err
}
defer serv.Close()
ctx.serv = serv
defer serv.Close() // TODO:
localCtx.serv = serv
// setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked
// for the race detector b/c it does not understand the synchronization via TCP socket connect/accept.
close(ctx.setupDone)
close(localCtx.setupDone)

eg, ctx := errgroup.WithContext(cfg.Context)

id := 0
connErr := serv.CreateInstance(id, nil, nil)
Expand All @@ -82,11 +88,14 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.GDB {
cmd.Stdin = os.Stdin
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start executor: %w", err)
}
res := make(chan error, 1)
go func() { res <- cmd.Wait() }()
eg.Go(func() error {
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start executor: %w", err)
}
return cmd.Wait()
})

// TODO: cancel context on shutdown.
shutdown := make(chan struct{})
if cfg.HandleInterrupts {
osutil.HandleInterrupts(shutdown)
Expand All @@ -112,10 +121,10 @@ type local struct {
setupDone chan bool
}

func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
<-ctx.setupDone
ctx.serv.TriagedCorpus()
return ctx.cfg.MachineChecked(features, syscalls)
return ctx.cfg.MachineChecked(features, syscalls), nil
}

func (ctx *local) BugFrames() ([]string, []string) {
Expand All @@ -126,6 +135,6 @@ func (ctx *local) MaxSignal() signal.Signal {
return signal.FromRaw(ctx.cfg.MaxSignal, 0)
}

func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
return ctx.cfg.CoverFilter
func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
return ctx.cfg.CoverFilter, nil
}
90 changes: 66 additions & 24 deletions pkg/rpcserver/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
"github.com/google/syzkaller/vm/dispatcher"
"golang.org/x/sync/errgroup"
)

type Config struct {
Expand Down Expand Up @@ -63,8 +64,8 @@ type RemoteConfig struct {
type Manager interface {
MaxSignal() signal.Signal
BugFrames() (leaks []string, races []string)
MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source
CoverageFilter(modules []*vminfo.KernelModule) []uint64
MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error)
CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error)
}

type Server interface {
Expand All @@ -88,6 +89,7 @@ type server struct {
checker *vminfo.Checker

infoOnce sync.Once
checkOnce sync.Once
checkDone atomic.Bool
checkFailures int
baseSource *queue.DynamicSourceCtl
Expand Down Expand Up @@ -217,23 +219,33 @@ func (serv *server) Close() error {
}

func (serv *server) Listen() error {
s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn)
s, err := flatrpc.Listen(serv.cfg.RPC)
if err != nil {
return err
}
serv.serv = s
return nil
}

func (serv *server) Serve(baseCtx context.Context) error {
eg, ctx := errgroup.WithContext(baseCtx)
eg.Go(func() error {
return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
return serv.handleConn(ctx, eg, conn)
})
})
return eg.Wait()
}

func (serv *server) Port() int {
return serv.serv.Addr.Port
}

func (serv *server) handleConn(conn *flatrpc.Conn) {
func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error {
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
log.Logf(1, "%s", err)
return
return nil
}
id := int(connectReq.Id)
log.Logf(1, "runner %v connected", id)
Expand All @@ -246,7 +258,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
serv.ShutdownInstance(id, true)
}()
} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
log.Fatal(err)
// This is a fatal error.
return err
}
serv.StatVMRestarts.Add(1)

Expand All @@ -255,15 +268,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
serv.mu.Unlock()
if runner == nil {
log.Logf(2, "unknown VM %v tries to connect", id)
return
return nil
}

err = serv.handleRunnerConn(runner, conn)
err = serv.handleRunnerConn(ctx, eg, runner, conn)
log.Logf(2, "runner %v: %v", id, err)

if err != nil && errors.Is(err, errFatal) {
log.Logf(0, "%v", err)
return err
}

runner.resultCh <- err
return nil
}

func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
runner *Runner, conn *flatrpc.Conn) error {
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
Files: serv.checker.RequiredFiles(),
Expand All @@ -278,12 +299,22 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
opts.Features = serv.cfg.Features
}

err := runner.Handshake(conn, opts)
info, err := runner.Handshake(conn, opts)
if err != nil {
log.Logf(1, "%v", err)
return err
}

serv.checkOnce.Do(func() {
// Run the machine check.
eg.Go(func() error {
if err := serv.runCheck(infoReq); err != nil {
return fmt.Errorf("check failed: %w", err)
}
return nil
})
})

if serv.triagedCorpus.Load() {
if err := runner.SendCorpusTriaged(); err != nil {
log.Logf(2, "%v", err)
Expand All @@ -294,6 +325,9 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
return serv.connectionLoop(runner)
}

// Used for errors incompatible with further RPCServer operation.
var errFatal = errors.New("aborting RPC server")

func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
Expand All @@ -307,31 +341,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
log.Logf(0, "machine check failed: %v", infoReq.Error)
serv.checkFailures++
if serv.checkFailures == 10 {
log.Fatalf("machine check failing")
return handshakeResult{}, errors.New("%w: machine check failed too many times", errFatal)
}
return handshakeResult{}, errors.New("machine check failed")
}
var retErr error
serv.infoOnce.Do(func() {
serv.StatModules.Add(len(modules))
serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover)
serv.coverFilter = serv.mgr.CoverageFilter(modules)
// Flatbuffers don't do deep copy of byte slices,
// so clone manually since we pass it a goroutine.
for _, file := range infoReq.Files {
file.Data = slices.Clone(file.Data)
var err error
serv.coverFilter, err = serv.mgr.CoverageFilter(modules)
if err != nil {
retErr = fmt.Errorf("%w: %w", errFatal, err)
return
}
// Now execute check programs.
go func() {
if err := serv.runCheck(infoReq); err != nil {
log.Fatalf("check failed: %v", err)
}
}()
})
if retErr != nil {
return handshakeResult{}, retErr
}
// Flatbuffers don't do deep copy of byte slices,
// so clone manually since we may later pass it a goroutine.
for _, file := range infoReq.Files {
file.Data = slices.Clone(file.Data)
}
canonicalizer := serv.canonicalModules.NewInstance(modules)
return handshakeResult{
CovFilter: canonicalizer.Decanonicalize(serv.coverFilter),
MachineInfo: machineInfo,
Canonicalizer: canonicalizer,
Files: infoReq.Files,
Features: infoReq.Features,
}, nil
}

Expand Down Expand Up @@ -371,7 +410,7 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
return nil
}

func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
func (serv *server) runCheck(info *handshakeResult) error {
enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(info.Files, info.Features)
enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls)
// Note: need to print disbled syscalls before failing due to an error.
Expand All @@ -384,7 +423,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
}
enabledFeatures := features.Enabled()
serv.setupFeatures = features.NeedSetup()
newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
if err != nil {
return err
}
serv.baseSource.Store(newSource)
serv.checkDone.Store(true)
return nil
Expand Down
Loading

0 comments on commit 2301da6

Please sign in to comment.