Skip to content

Commit

Permalink
refactor: use sync.OnceValue[s]
Browse files Browse the repository at this point in the history
  • Loading branch information
Zxilly committed Oct 11, 2024
1 parent e4fad86 commit 72f89cf
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 122 deletions.
52 changes: 27 additions & 25 deletions elf.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"errors"
"fmt"
"os"
"sync"
)

func openELF(fp string) (*elfFile, error) {
Expand All @@ -35,52 +36,53 @@ func openELF(fp string) (*elfFile, error) {
if err != nil {
return nil, fmt.Errorf("error when parsing the ELF file: %w", err)
}
return &elfFile{file: f, osFile: osFile, symtab: newSymbolTableOnce()}, nil
ret := &elfFile{file: f, osFile: osFile}
ret.getsymtab = sync.OnceValues(ret.initSymTab)
return ret, nil
}

var _ fileHandler = (*elfFile)(nil)

type elfFile struct {
file *elf.File
osFile *os.File
symtab *symbolTableOnce
file *elf.File
osFile *os.File
getsymtab func() (map[string]Symbol, error)
}

func (e *elfFile) initSymTab() error {
e.symtab.Do(func() {
syms, err := e.file.Symbols()
if err != nil {
// If the error is ErrNoSymbols, we just ignore it.
if !errors.Is(err, elf.ErrNoSymbols) {
e.symtab.err = fmt.Errorf("error when getting the symbols: %w", err)
}
return
func (e *elfFile) initSymTab() (map[string]Symbol, error) {
syms, err := e.file.Symbols()
if err != nil {
// If the error is ErrNoSymbols, we just ignore it.
if !errors.Is(err, elf.ErrNoSymbols) {
return nil, fmt.Errorf("error when getting the symbols: %w", err)
}
for _, sym := range syms {
e.symtab.table[sym.Name] = Symbol{
Name: sym.Name,
Value: sym.Value,
Size: sym.Size,
}
return nil, nil
}
symm := make(map[string]Symbol)
for _, sym := range syms {
symm[sym.Name] = Symbol{
Name: sym.Name,
Value: sym.Value,
Size: sym.Size,
}
})
return e.symtab.err
}
return symm, nil
}

func (e *elfFile) hasSymbolTable() (bool, error) {
err := e.initSymTab()
symm, err := e.getsymtab()
if err != nil {
return false, err
}
return len(e.symtab.table) > 0, nil
return len(symm) > 0, nil
}

func (e *elfFile) getSymbol(name string) (uint64, uint64, error) {
err := e.initSymTab()
symm, err := e.getsymtab()
if err != nil {
return 0, 0, err
}
sym, ok := e.symtab.table[name]
sym, ok := symm[name]
if !ok {
return 0, 0, ErrSymbolNotFound
}
Expand Down
82 changes: 40 additions & 42 deletions macho.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"fmt"
"os"
"slices"
"sort"
"sync"
)

func openMachO(fp string) (*machoFile, error) {
Expand All @@ -36,66 +36,64 @@ func openMachO(fp string) (*machoFile, error) {
if err != nil {
return nil, fmt.Errorf("error when parsing the Mach-O file: %w", err)
}
return &machoFile{file: f, osFile: osFile, symtab: newSymbolTableOnce()}, nil
ret := &machoFile{file: f, osFile: osFile}
ret.getsymtab = sync.OnceValue(ret.initSymtab)
return ret, nil
}

var _ fileHandler = (*machoFile)(nil)

type machoFile struct {
file *macho.File
osFile *os.File
symtab *symbolTableOnce
file *macho.File
osFile *os.File
getsymtab func() map[string]Symbol
}

func (m *machoFile) initSymtab() error {
m.symtab.Do(func() {
if m.file.Symtab == nil {
// just do nothing, keep err nil and table empty
return
func (m *machoFile) initSymtab() (symm map[string]Symbol) {
if m.file.Symtab == nil {
// just do nothing, keep err nil and table empty
return
}
symm = make(map[string]Symbol)
const stabTypeMask = 0xe0
// Build a sorted list of addresses of all symbols.
// We infer the size of a symbol by looking at where the next symbol begins.
var addrs []uint64
for _, s := range m.file.Symtab.Syms {
// Skip stab debug info.
if s.Type&stabTypeMask == 0 {
addrs = append(addrs, s.Value)
}
const stabTypeMask = 0xe0
// Build a sorted list of addresses of all symbols.
// We infer the size of a symbol by looking at where the next symbol begins.
var addrs []uint64
for _, s := range m.file.Symtab.Syms {
}
slices.Sort(addrs)

var syms []Symbol
for _, s := range m.file.Symtab.Syms {
if s.Type&stabTypeMask != 0 {
// Skip stab debug info.
if s.Type&stabTypeMask == 0 {
addrs = append(addrs, s.Value)
}
continue
}
slices.Sort(addrs)

var syms []Symbol
for _, s := range m.file.Symtab.Syms {
if s.Type&stabTypeMask != 0 {
// Skip stab debug info.
continue
}
sym := Symbol{Name: s.Name, Value: s.Value}
i := sort.Search(len(addrs), func(x int) bool { return addrs[x] > s.Value })
if i < len(addrs) {
sym.Size = addrs[i] - s.Value
}
syms = append(syms, sym)
sym := Symbol{Name: s.Name, Value: s.Value}
i, found := slices.BinarySearch(addrs, s.Value)
if found {
sym.Size = addrs[i] - s.Value
}
syms = append(syms, sym)
}

for _, sym := range syms {
m.symtab.table[sym.Name] = sym
}
})
return nil
for _, sym := range syms {
symm[sym.Name] = sym
}

return
}

func (m *machoFile) hasSymbolTable() (bool, error) {
return m.file.Symtab != nil, nil
}

func (m *machoFile) getSymbol(name string) (uint64, uint64, error) {
err := m.initSymtab()
if err != nil {
return 0, 0, err
}
sym, ok := m.symtab.table[name]
sym, ok := m.getsymtab()[name]
if !ok {
return 0, 0, ErrSymbolNotFound
}
Expand Down
88 changes: 44 additions & 44 deletions pe.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ import (
"fmt"
"os"
"slices"
"sort"
"sync"
)

func openPE(fp string) (peF *peFile, err error) {
// Parsing by the file by debug/pe can panic if the PE file is malformed.
// To prevent a crash, we recover the panic and return it as an error
// instead.
go func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("error when processing PE file, probably corrupt: %s", r)
}
Expand Down Expand Up @@ -62,7 +62,8 @@ func openPE(fp string) (peF *peFile, err error) {
return
}

peF = &peFile{file: f, osFile: osFile, imageBase: imageBase, symtab: newSymbolTableOnce()}
peF = &peFile{file: f, osFile: osFile, imageBase: imageBase}
peF.getsymtab = sync.OnceValues(peF.initSymTab)
return
}

Expand All @@ -72,64 +73,63 @@ type peFile struct {
file *pe.File
osFile *os.File
imageBase uint64
symtab *symbolTableOnce
}

func (p *peFile) initSymTab() error {
p.symtab.Do(func() {
var addrs []uint64

var syms []Symbol
for _, s := range p.file.Symbols {
const (
NUndef = 0 // An undefined (extern) symbol
NAbs = -1 // An absolute symbol (e_value is a constant, not an address)
NDebug = -2 // A debugging symbol
)
sym := Symbol{Name: s.Name, Value: uint64(s.Value), Size: 0}
switch s.SectionNumber {
case NUndef, NAbs, NDebug: // do nothing
default:
if s.SectionNumber < 0 || len(p.file.Sections) < int(s.SectionNumber) {
p.symtab.err = fmt.Errorf("invalid section number in symbol table")
return
}
sect := p.file.Sections[s.SectionNumber-1]
sym.Value += p.imageBase + uint64(sect.VirtualAddress)
getsymtab func() (map[string]Symbol, error)
}

func (p *peFile) initSymTab() (map[string]Symbol, error) {
var addrs []uint64

var syms []Symbol
for _, s := range p.file.Symbols {
const (
NUndef = 0 // An undefined (extern) symbol
NAbs = -1 // An absolute symbol (e_value is a constant, not an address)
NDebug = -2 // A debugging symbol
)
sym := Symbol{Name: s.Name, Value: uint64(s.Value), Size: 0}
switch s.SectionNumber {
case NUndef, NAbs, NDebug: // do nothing
default:
if s.SectionNumber < 0 || len(p.file.Sections) < int(s.SectionNumber) {
return nil, fmt.Errorf("invalid section number in symbol table")
}
syms = append(syms, sym)
addrs = append(addrs, sym.Value)
sect := p.file.Sections[s.SectionNumber-1]
sym.Value += p.imageBase + uint64(sect.VirtualAddress)
}
syms = append(syms, sym)
addrs = append(addrs, sym.Value)
}

slices.Sort(addrs)
for i := range syms {
j := sort.Search(len(addrs), func(x int) bool { return addrs[x] > syms[i].Value })
if j < len(addrs) {
syms[i].Size = addrs[j] - syms[i].Value
}
slices.Sort(addrs)
for i := range syms {
j, found := slices.BinarySearch(addrs, syms[i].Value)
if found {
syms[i].Size = addrs[j] - syms[i].Value
}
}

for _, sym := range syms {
p.symtab.table[sym.Name] = sym
}
})
return p.symtab.err
symm := make(map[string]Symbol)
for _, sym := range syms {
symm[sym.Name] = sym
}

return symm, nil
}

func (p *peFile) hasSymbolTable() (bool, error) {
err := p.initSymTab()
symm, err := p.getsymtab()
if err != nil {
return false, err
}
return len(p.symtab.table) > 0, nil
return len(symm) > 0, nil
}

func (p *peFile) getSymbol(name string) (uint64, uint64, error) {
err := p.initSymTab()
symm, err := p.getsymtab()
if err != nil {
return 0, 0, err
}
sym, ok := p.symtab.table[name]
sym, ok := symm[name]
if !ok {
return 0, 0, ErrSymbolNotFound
}
Expand Down
11 changes: 0 additions & 11 deletions symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gore

import (
"errors"
"sync"
)

var ErrSymbolNotFound = errors.New("symbol not found")
Expand All @@ -14,13 +13,3 @@ type Symbol struct {
// Size of the symbol. Only accurate on ELF files. For Mach-O and PE files, it was inferred by looking at the next symbol.
Size uint64
}

type symbolTableOnce struct {
*sync.Once
table map[string]Symbol
err error
}

func newSymbolTableOnce() *symbolTableOnce {
return &symbolTableOnce{Once: &sync.Once{}, table: make(map[string]Symbol)}
}

0 comments on commit 72f89cf

Please sign in to comment.