diff --git a/pkg/sbf/interpreter.go b/pkg/sbf/interpreter.go new file mode 100644 index 0000000..67c2668 --- /dev/null +++ b/pkg/sbf/interpreter.go @@ -0,0 +1 @@ +package sbf diff --git a/pkg/sbf/loader/arithmetic.go b/pkg/sbf/loader/arithmetic.go index eee9df6..f9d164d 100644 --- a/pkg/sbf/loader/arithmetic.go +++ b/pkg/sbf/loader/arithmetic.go @@ -13,6 +13,14 @@ func clampAddUint64(x uint64, y uint64) uint64 { return z } +func clampSubUint64(x uint64, y uint64) uint64 { + z, borrow := bits.Sub64(x, y, 0) + if borrow != 0 { + return 0 + } + return z +} + type addrRange struct { min, max uint64 } diff --git a/pkg/sbf/loader/copy.go b/pkg/sbf/loader/copy.go index 2c59a9a..2484139 100644 --- a/pkg/sbf/loader/copy.go +++ b/pkg/sbf/loader/copy.go @@ -34,7 +34,7 @@ func (l *Loader) getText() error { if err := l.checkSectionAddrs(l.shText); err != nil { return fmt.Errorf("invalid .text: %w", err) } - l.text = addrRange{min: l.shText.Off, max: l.shText.Off + l.shText.Size} + l.textRange = addrRange{min: l.shText.Off, max: l.shText.Off + l.shText.Size} return nil } @@ -72,7 +72,7 @@ func (l *Loader) mapSections() error { } l.progRange.insert(section) - if section.min != l.text.min { + if section.min != l.textRange.min { l.rodatas = append(l.rodatas, section) } } @@ -116,10 +116,13 @@ func (l *Loader) copySections() error { return err } } - if err := l.copySection(l.text); err != nil { + if err := l.copySection(l.textRange); err != nil { return err } + // Special sub-slice for text + l.text = l.getRange(l.textRange) + return nil } diff --git a/pkg/sbf/loader/loader.go b/pkg/sbf/loader/loader.go index 07adbc8..44e81fc 100644 --- a/pkg/sbf/loader/loader.go +++ b/pkg/sbf/loader/loader.go @@ -39,12 +39,14 @@ type Loader struct { // Program section/segment mappings // Uses physical addressing rodatas []addrRange - text addrRange + textRange addrRange progRange addrRange // Contains most of ELF (.text and rodata-like) // Non-loaded sections are zeroed - program []byte + program []byte + text []byte + entrypoint uint64 // program counter // Symbols funcs map[uint32]uint64 @@ -77,6 +79,9 @@ func NewLoaderFromBytes(buf []byte) (*Loader, error) { } // Load parses, loads, and relocates an SBF program. +// +// This loader differs from rbpf in a few ways: +// We don't support spec bugs, we relocate after loading. func (l *Loader) Load() (*sbf.Program, error) { if err := l.parse(); err != nil { return nil, err @@ -87,8 +92,13 @@ func (l *Loader) Load() (*sbf.Program, error) { if err := l.relocate(); err != nil { return nil, err } - prog := &sbf.Program{ - RO: l.program, - } - return prog, nil + return l.getProgram(), nil +} + +func (l *Loader) getProgram() *sbf.Program { + return &sbf.Program{ + RO: l.program, + Text: l.text, + Entrypoint: l.entrypoint, + } } diff --git a/pkg/sbf/loader/loader_test.go b/pkg/sbf/loader/loader_test.go index a83d942..6a5757d 100644 --- a/pkg/sbf/loader/loader_test.go +++ b/pkg/sbf/loader/loader_test.go @@ -14,7 +14,7 @@ var ( soNoop []byte ) -func TestLoadProgram_Noop(t *testing.T) { +func TestLoader_Noop(t *testing.T) { loader, err := NewLoaderFromBytes(soNoop) require.NoError(t, err) @@ -163,20 +163,28 @@ func TestLoadProgram_Noop(t *testing.T) { assert.Equal(t, addrRange{ min: 0x1000, max: 0x1060, - }, loader.text) + }, loader.textRange) assertZeroBytes(t, loader.program[:loader.rodatas[0].min]) assert.Equal(t, soNoop[loader.rodatas[0].min:loader.rodatas[0].max], loader.getRange(loader.rodatas[0])) - assertZeroBytes(t, loader.program[loader.rodatas[0].max:loader.text.min]) + assertZeroBytes(t, loader.program[loader.rodatas[0].max:loader.textRange.min]) assert.Equal(t, - soNoop[loader.text.min:loader.text.max], - loader.getRange(loader.text)) - assertZeroBytes(t, loader.program[loader.text.max:]) + soNoop[loader.textRange.min:loader.textRange.max], + loader.getRange(loader.textRange)) + assertZeroBytes(t, loader.program[loader.textRange.max:]) + + assert.Equal(t, + soNoop[loader.textRange.min:loader.textRange.max], + loader.text) err = loader.relocate() require.NoError(t, err) + + assert.Equal(t, uint64(0), loader.entrypoint) + + assert.NotNil(t, loader.getProgram()) } func assertZeroBytes(t *testing.T, b []byte) { @@ -193,3 +201,13 @@ func isZeroBytes(b []byte) bool { } return true } + +func TestVerifier(t *testing.T) { + loader, err := NewLoaderFromBytes(soNoop) + require.NoError(t, err) + + program, err := loader.Load() + require.NoError(t, err) + + require.NoError(t, program.Verify()) +} diff --git a/pkg/sbf/loader/relocate.go b/pkg/sbf/loader/relocate.go index 8614f67..91b4bad 100644 --- a/pkg/sbf/loader/relocate.go +++ b/pkg/sbf/loader/relocate.go @@ -11,19 +11,23 @@ import ( // relocate applies ELF relocations (for syscalls and position-independent code). func (l *Loader) relocate() error { + l.funcs = make(map[uint32]uint64) if err := l.fixupRelativeCalls(); err != nil { return err } if err := l.applyDynamicRelocs(); err != nil { return err } + if err := l.getEntrypoint(); err != nil { + return err + } return nil } func (l *Loader) fixupRelativeCalls() error { // TODO does invariant text.size%8 == 0 hold? - insCount := l.text.len() / sbf.SlotSize - buf := l.getRange(l.text) + insCount := l.textRange.len() / sbf.SlotSize + buf := l.getRange(l.textRange) for i := uint64(0); i < insCount; i++ { off := i * sbf.SlotSize slot := sbf.GetSlot(buf[off : off+sbf.SlotSize]) @@ -99,7 +103,7 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { binary.LittleEndian.PutUint32(l.program[rOff+4:rOff+8], uint32(addr)) binary.LittleEndian.PutUint32(l.program[rOff+12:rOff+16], uint32(addr>>32)) case R_BPF_64_RELATIVE: - if l.text.contains(rOff) { + if l.textRange.contains(rOff) { immLow := binary.LittleEndian.Uint32(l.program[rOff+4 : rOff+8]) immHi := binary.LittleEndian.Uint32(l.program[rOff+12 : rOff+16]) @@ -141,10 +145,10 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { var hash uint32 if elf.ST_TYPE(sym.Info) == elf.STT_FUNC && sym.Value != 0 { // Function call - if !l.text.contains(sym.Value) { + if !l.textRange.contains(sym.Value) { return fmt.Errorf("out-of-bounds R_BPF_64_32 function ref") } - target := (sym.Value - l.text.min) / 8 + target := (sym.Value - l.textRange.min) / 8 hash, err = l.registerFunc(target) if err != nil { return fmt.Errorf("R_BPF_64_32 function ref: %w", err) @@ -162,6 +166,15 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { return nil } +func (l *Loader) getEntrypoint() error { + offset := l.eh.Entry - l.shText.Addr + if offset%sbf.SlotSize != 0 { + return fmt.Errorf("invalid entrypoint") + } + l.entrypoint = offset / sbf.SlotSize + return nil +} + const ( // EntrypointHash equals SymbolHash("entrypoint") EntrypointHash = uint32(0x71e3cf81) diff --git a/pkg/sbf/program.go b/pkg/sbf/program.go index b3d4894..f8bbdaa 100644 --- a/pkg/sbf/program.go +++ b/pkg/sbf/program.go @@ -2,5 +2,12 @@ package sbf // Program is a loaded SBF program. type Program struct { - RO []byte // read-only segment containing text and ELFs + RO []byte // read-only segment containing text and ELFs + Text []byte + Entrypoint uint64 // PC +} + +// Verify runs the static bytecode verifier. +func (p *Program) Verify() error { + return NewVerifier(p).Verify() } diff --git a/pkg/sbf/sbf.go b/pkg/sbf/sbf.go index 789a329..547355d 100644 --- a/pkg/sbf/sbf.go +++ b/pkg/sbf/sbf.go @@ -34,25 +34,30 @@ func GetSlot(buf []byte) Slot { // Op returns the opcode field. func (s Slot) Op() uint8 { - return uint8(s >> 56) + return uint8(s) } // Dst returns the destination register field. func (s Slot) Dst() uint8 { - return uint8(s>>52) & 0xF + return uint8(s>>12) & 0xF } // Src returns the source register field. func (s Slot) Src() uint8 { - return uint8(s>>48) & 0xF + return uint8(s>>8) & 0xF } // Off returns the offset field. -func (s Slot) Off() uint16 { - return uint16(s >> 32) +func (s Slot) Off() int16 { + return int16(uint16(s >> 16)) } // Imm returns the immediate field. func (s Slot) Imm() int32 { - return int32(uint32(s)) + return int32(uint32(s >> 32)) +} + +// Uimm returns the immediate field as unsigned. +func (s Slot) Uimm() uint32 { + return uint32(s >> 32) } diff --git a/pkg/sbf/verifier.go b/pkg/sbf/verifier.go new file mode 100644 index 0000000..c4a34d7 --- /dev/null +++ b/pkg/sbf/verifier.go @@ -0,0 +1,113 @@ +package sbf + +import "fmt" + +type Verifier struct { + Program *Program +} + +func NewVerifier(p *Program) *Verifier { + return &Verifier{Program: p} +} + +func (v *Verifier) Verify() error { + text := v.Program.Text + if len(text)%SlotSize != 0 { + return fmt.Errorf("odd .text size") + } + if len(text) == 0 { + return fmt.Errorf("empty text") + } + + for pc := uint64(0); (pc+1)*SlotSize <= uint64(len(text)); pc++ { + insBytes := text[pc*SlotSize:] + ins := GetSlot(insBytes) + + if ins.Src() > 10 { + return fmt.Errorf("invalid src register") + } + switch ins.Op() { + case OpLdxb, OpLdxh, OpLdxw, OpLdxdw: + case OpAdd32Imm, OpAdd32Reg, OpAdd64Imm, OpAdd64Reg: + case OpSub32Imm, OpSub32Reg, OpSub64Imm, OpSub64Reg: + case OpMul32Imm, OpMul32Reg, OpMul64Imm, OpMul64Reg: + case OpOr32Imm, OpOr32Reg, OpOr64Imm, OpOr64Reg: + case OpAnd32Imm, OpAnd32Reg, OpAnd64Imm, OpAnd64Reg: + case OpNeg32, OpNeg64: + case OpXor32Imm, OpXor32Reg, OpXor64Imm, OpXor64Reg: + case OpMov32Imm, OpMov32Reg, OpMov64Imm, OpMov64Reg: + case OpDiv32Reg, OpDiv64Reg: + case OpMod32Reg, OpMod64Reg: + case OpSdiv32Reg, OpSdiv64Reg: + case OpCall, OpExit: + // nothing + case OpStb, OpSth, OpStw, OpStdw, + OpStxb, OpStxh, OpStxw, OpStxdw: + if ins.Dst() > 10 { + return fmt.Errorf("invalid dst register") + } + continue + case OpLsh32Imm, OpRsh32Imm, OpArsh32Imm: + if ins.Uimm() > 31 { + return fmt.Errorf("32-bit shift out of bounds") + } + case OpLsh64Imm, OpRsh64Imm, OpArsh64Imm: + if ins.Uimm() > 63 { + return fmt.Errorf("64-bit shift out of bounds") + } + case OpLe, OpBe: + switch ins.Uimm() { + case 16, 32, 64: + // ok + default: + return fmt.Errorf("invalid bit size for endianness conversion") + } + case OpSdiv32Imm, OpSdiv64Imm: + fallthrough + case OpDiv32Imm, OpDiv64Imm, OpMod32Imm, OpMod64Imm: + if ins.Imm() == 0 { + return fmt.Errorf("division by zero") + } + case OpJa, + OpJeqImm, OpJeqReg, + OpJgtImm, OpJgtReg, + OpJgeImm, OpJgeReg, + OpJltImm, OpJltReg, + OpJleImm, OpJleReg, + OpJsetImm, OpJsetReg, + OpJneImm, OpJneReg, + OpJsgtImm, OpJsgtReg, + OpJsgeImm, OpJsgeReg, + OpJsltImm, OpJsltReg, + OpJsleImm, OpJsleReg: + dst := int64(pc) + int64(ins.Off()) + 1 + if dst < 0 || (dst*SlotSize) >= int64(len(text)) { + return fmt.Errorf("jump out of code") + } + dstIns := GetSlot(text[dst*SlotSize:]) + if dstIns.Op() == 0 { + return fmt.Errorf("jump into middle of instruction") + } + case OpCallx: + if uimm := ins.Uimm(); uimm >= 10 { + return fmt.Errorf("invalid callx register") + } + case OpLddw: + if len(insBytes) < 2*SlotSize { + return fmt.Errorf("incomplete lddw instruction") + } + if insBytes[8] != 0 { + return fmt.Errorf("malformed lddw instruction") + } + pc++ + default: + return fmt.Errorf("unknown opcode %#02x", ins.Op()) + } + + if ins.Dst() > 9 { + return fmt.Errorf("invalid dst register") + } + } + + return nil +}