sbf: add stack

This commit is contained in:
Richard Patel 2022-09-04 21:49:44 +02:00
parent 196e826d6b
commit 69819dc852
6 changed files with 231 additions and 61 deletions

3
pkg/sbf/cu.go Normal file
View File

@ -0,0 +1,3 @@
package sbf
// This file contains helper routines for the calculation of compute units.

View File

@ -9,17 +9,18 @@ import (
// Interpreter implements the SBF core in pure Go. // Interpreter implements the SBF core in pure Go.
type Interpreter struct { type Interpreter struct {
text []byte textVA uint64
ro []byte text []byte
stack []byte ro []byte
heap []byte stack Stack
input []byte heap []byte
input []byte
entry uint64 entry uint64
cuMax uint64 cuMax uint64
syscalls map[uint32]Syscall syscalls map[uint32]Syscall
funcs map[uint32]int64
vmContext any vmContext any
} }
@ -31,7 +32,7 @@ func NewInterpreter(p *Program, opts *VMOpts) *Interpreter {
return &Interpreter{ return &Interpreter{
text: p.Text, text: p.Text,
ro: p.RO, ro: p.RO,
stack: make([]byte, opts.StackSize), stack: NewStack(),
heap: make([]byte, opts.HeapSize), heap: make([]byte, opts.HeapSize),
input: opts.Input, input: opts.Input,
entry: p.Entrypoint, entry: p.Entrypoint,
@ -45,16 +46,18 @@ func NewInterpreter(p *Program, opts *VMOpts) *Interpreter {
// //
// This function may panic given code that doesn't pass the static verifier. // This function may panic given code that doesn't pass the static verifier.
func (i *Interpreter) Run() (err error) { func (i *Interpreter) Run() (err error) {
// Deliberately implementing the entire core in a single function here
// to give the compiler more creative liberties.
var r [11]uint64 var r [11]uint64
r[1] = VaddrInput r[1] = VaddrInput
// TODO frame pointer // TODO frame pointer
pc := int64(i.entry) pc := int64(i.entry)
cuLeft := int64(i.cuMax) cuLeft := int64(i.cuMax)
// TODO step to next instruction // Design notes
// - The interpreter is deliberately implemented in a single big loop,
// to give the compiler more creative liberties, and avoid escaping hot data to the heap.
// - uint64(int32(x)) performs sign extension. Most ALU64 instructions make use of this.
// - The static verifier imposes invariants on the bytecode.
// The interpreter may panic when it notices these invariants are violated (e.g. invalid opcode)
mainLoop: mainLoop:
for { for {
@ -367,21 +370,49 @@ mainLoop:
// TODO use src reg hint // TODO use src reg hint
if sc, ok := i.syscalls[ins.Uimm()]; ok { if sc, ok := i.syscalls[ins.Uimm()]; ok {
r[0], cuLeft, err = sc.Invoke(i, r[1], r[2], r[3], r[4], r[5], cuLeft) r[0], cuLeft, err = sc.Invoke(i, r[1], r[2], r[3], r[4], r[5], cuLeft)
} else if target, ok := i.funcs[ins.Uimm()]; ok {
r[10], ok = i.stack.Push((*[4]uint64)(r[6:10]), pc+1)
if !ok {
err = ExcCallDepth
}
pc = target
} else { } else {
panic("bpf function calls not implemented") err = ExcCallDest
} }
case OpCallx: case OpCallx:
panic("callx not implemented") target := r[ins.Uimm()]
target &= ^(uint64(0x7))
var ok bool
r[10], ok = i.stack.Push((*[4]uint64)(r[6:10]), pc+1)
if !ok {
err = ExcCallDepth
}
if target < i.textVA || target >= VaddrStack || target >= i.textVA+uint64(len(i.text)) {
err = NewExcBadAccess(target, 8, false, "jump out-of-bounds")
}
pc = int64((target - i.textVA) / 8)
case OpExit: case OpExit:
// TODO implement function returns var ok bool
break mainLoop r[10], pc, ok = i.stack.Pop((*[4]uint64)(r[6:10]))
if !ok {
break mainLoop
}
default: default:
panic(fmt.Sprintf("unimplemented opcode %#02x", ins.Op())) panic(fmt.Sprintf("unimplemented opcode %#02x", ins.Op()))
} }
// Post execute // Post execute
if cuLeft < 0 {
err = ExcOutOfCU
}
if err != nil { if err != nil {
// TODO return CPU exception error type here exc := &Exception{
return err PC: pc,
Detail: err,
}
if IsLongIns(ins.Op()) {
exc.PC-- // fix reported PC
}
return exc
} }
pc++ pc++
} }
@ -412,7 +443,11 @@ func (i *Interpreter) Translate(addr uint64, size uint32, write bool) (unsafe.Po
} }
return unsafe.Pointer(&i.ro[lo]), nil return unsafe.Pointer(&i.ro[lo]), nil
case VaddrStack >> 32: case VaddrStack >> 32:
panic("todo implement stack access check") mem := i.stack.GetFrame(uint32(addr))
if uint32(len(mem)) < size {
return nil, NewExcBadAccess(addr, size, write, "out-of-bounds stack access")
}
return unsafe.Pointer(&mem[0]), nil
case VaddrHeap >> 32: case VaddrHeap >> 32:
panic("todo implement heap access check") panic("todo implement heap access check")
case VaddrInput >> 32: case VaddrInput >> 32:

View File

@ -20,8 +20,6 @@ const (
MaxInsSize = 2 * SlotSize MaxInsSize = 2 * SlotSize
) )
const StackFrameSize = 0x1000
func IsLongIns(op uint8) bool { func IsLongIns(op uint8) bool {
return op == OpLddw return op == OpLddw
} }

118
pkg/sbf/stack.go Normal file
View File

@ -0,0 +1,118 @@
package sbf
// Stack is the VM's call frame stack.
//
// # Memory stack
//
// The memory stack resides in addressable memory at VaddrStack.
//
// It is split into statically sized stack frames (StackFrameSize).
// Each frame stores spilled function arguments and local variables.
// The frame pointer (r10) points to the highest address in the current frame.
//
// New frames get allocated upwards.
// Each frame is followed by a gap of size StackFrameSize.
//
// [0x1_0000_0000]: Frame
// [0x1_0000_1000]: Gap
// [0x1_0000_2000]: Frame
// [0x1_0000_3000]: Gap
// ...
//
// # Shadow stack
//
// The shadow stack is not directly accessible from SBF.
// It stores return addresses and the nonvolatile registers saved by Push.
type Stack struct {
// mem is the backing memory of the memory stack.
// NewStack sizes it as StackDepth*StackFrameSize, i.e. it holds frames only;
// GetFrame maps frame addresses onto it and returns nil for gaps.
mem []byte
// sp is a virtual address: NewStack initializes it to VaddrStack and
// Push sets it to the new frame pointer minus StackFrameSize.
sp uint64
// shadow is the shadow stack.
// GetFramePtr reads the current frame pointer from its last element.
shadow []Frame
}
// Frame is an entry on the shadow stack.
//
// Push records one Frame per call from the values it is handed.
type Frame struct {
// FramePtr is the frame pointer of the call frame this entry describes.
FramePtr uint64
// NVRegs holds the four nonvolatile register values passed to Push.
NVRegs [4]uint64
// RetAddr is the return address passed to Push.
RetAddr int64
}
// Dimensions of the VM call stack.
const (
	// StackFrameSize is the number of addressable bytes in one stack frame.
	//
	// Note that this constant cannot be changed trivially.
	StackFrameSize = 0x1000

	// StackDepth is the maximum number of frames the stack can hold.
	StackDepth = 64
)
// NewStack returns a stack holding only the initial (entrypoint) frame.
//
// The frame memory for all StackDepth frames is allocated up front, and the
// shadow stack is created with room for StackDepth entries. The initial
// frame pointer is VaddrStack + StackFrameSize.
func NewStack() Stack {
	frames := make([]Frame, 1, StackDepth)
	frames[0].FramePtr = VaddrStack + StackFrameSize
	return Stack{
		mem:    make([]byte, StackDepth*StackFrameSize),
		sp:     VaddrStack,
		shadow: frames,
	}
}
// GetFramePtr reports the frame pointer recorded in the topmost shadow
// stack entry, i.e. the frame pointer of the current call frame.
func (s *Stack) GetFramePtr() uint64 {
	top := len(s.shadow) - 1
	return s.shadow[top].FramePtr
}
// GetFrame returns the stack frame memory addressed by addr.
//
// addr is the offset from the start of the stack region (the low 32 bits of
// a stack virtual address). The region consists of StackFrameSize-sized
// slots that alternate between frames (even slots) and gaps (odd slots), so
// frame k occupies slot 2*k.
//
// The returned slice starts at the location within the frame as indicated
// by the address and runs to the end of that frame. To get the full frame,
// align the provided address down to StackFrameSize.
//
// Returns nil if the program tries to address a gap or out-of-bounds memory.
func (s *Stack) GetFrame(addr uint32) []byte {
	slot, lo := addr/StackFrameSize, addr%StackFrameSize
	if slot%2 == 1 {
		return nil // odd slots are the gaps between frames
	}
	// Frames and gaps interleave, so the frame index is half the slot
	// index and must stay below StackDepth (slots 0..2*StackDepth-2).
	idx := int(slot / 2)
	if idx >= StackDepth {
		return nil
	}
	base := idx * StackFrameSize
	end := base + StackFrameSize
	if end > len(s.mem) {
		return nil // backing memory missing, e.g. zero-value Stack
	}
	return s.mem[base+int(lo) : end]
}
// Push allocates a new call frame.
//
// It saves the given nonvolatile regs and return address on the shadow
// stack and returns the frame pointer of the new frame, which lies
// 2*StackFrameSize (one frame plus one gap) above the current one.
//
// ok is false, and the stack is left untouched, if the maximum call depth
// (the capacity of the shadow stack) has been reached.
func (s *Stack) Push(nvRegs *[4]uint64, ret int64) (fp uint64, ok bool) {
	depth := len(s.shadow)
	if depth == 0 || depth >= cap(s.shadow) {
		return 0, false
	}
	fp = s.shadow[depth-1].FramePtr + 2*StackFrameSize
	// Grow the shadow stack by one entry within its preallocated capacity.
	s.shadow = s.shadow[:depth+1]
	s.shadow[depth] = Frame{
		FramePtr: fp,
		NVRegs:   *nvRegs,
		RetAddr:  ret,
	}
	s.sp = fp - StackFrameSize
	return fp, true
}
// Pop exits the last call frame.
//
// It removes the topmost shadow stack entry, writes the nonvolatile regs
// saved in it into nvRegs, and returns the frame pointer of the frame that
// becomes current together with the saved return address.
//
// ok is false if no call frames are left, i.e. only the initial frame
// remains; in that case nvRegs and the stack are not modified.
func (s *Stack) Pop(nvRegs *[4]uint64) (fp uint64, ret int64, ok bool) {
	depth := len(s.shadow)
	if depth <= 1 {
		return 0, 0, false
	}
	// Frames are pushed at the end of the shadow stack, so pop from there.
	frame := s.shadow[depth-1]
	s.shadow = s.shadow[:depth-1]
	fp = s.shadow[depth-2].FramePtr
	// Mirror Push, which keeps sp at StackFrameSize below the frame pointer.
	s.sp = fp - StackFrameSize
	*nvRegs = frame.NVRegs
	return fp, frame.RetAddr, true
}

View File

@ -26,6 +26,11 @@ func PCHash(addr uint64) uint32 {
return murmur3.Sum32(key[:]) return murmur3.Sum32(key[:])
} }
// Syscall is a callback handle from the VM to Go. (work in progress)
type Syscall interface {
// Invoke runs the syscall on behalf of vm.
//
// r1 through r5 carry the argument register values and r0 is the result.
// cuIn and cuOut are presumably the remaining compute units before and
// after the call (the interpreter passes in and reassigns its compute
// unit counter) — TODO confirm once the accounting is finalized.
Invoke(vm VM, r1, r2, r3, r4, r5 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)
}
type SyscallRegistry map[uint32]Syscall type SyscallRegistry map[uint32]Syscall
func NewSyscallRegistry() SyscallRegistry { func NewSyscallRegistry() SyscallRegistry {
@ -41,3 +46,41 @@ func (s SyscallRegistry) Register(name string, syscall Syscall) (hash uint32, ok
ok = true ok = true
return return
} }
// Convenience adapters.
//
// Each SyscallFuncN type lets a plain Go function taking N register
// arguments be used as a Syscall; unused registers are discarded.

// SyscallFunc0 is a syscall handler that takes no register arguments.
type SyscallFunc0 func(vm VM, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn, ignoring r1 through r5.
func (fn SyscallFunc0) Invoke(vm VM, _, _, _, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, cuIn)
	return r0, cuOut, err
}

// SyscallFunc1 is a syscall handler that takes r1.
type SyscallFunc1 func(vm VM, r1 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1, ignoring r2 through r5.
func (fn SyscallFunc1) Invoke(vm VM, r1, _, _, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, cuIn)
	return r0, cuOut, err
}

// SyscallFunc2 is a syscall handler that takes r1 and r2.
type SyscallFunc2 func(vm VM, r1, r2 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 and r2, ignoring r3 through r5.
func (fn SyscallFunc2) Invoke(vm VM, r1, r2, _, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, cuIn)
	return r0, cuOut, err
}

// SyscallFunc3 is a syscall handler that takes r1 through r3.
type SyscallFunc3 func(vm VM, r1, r2, r3 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 through r3, ignoring r4 and r5.
func (fn SyscallFunc3) Invoke(vm VM, r1, r2, r3, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, r3, cuIn)
	return r0, cuOut, err
}

// SyscallFunc4 is a syscall handler that takes r1 through r4.
type SyscallFunc4 func(vm VM, r1, r2, r3, r4 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 through r4, ignoring r5.
func (fn SyscallFunc4) Invoke(vm VM, r1, r2, r3, r4, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, r3, r4, cuIn)
	return r0, cuOut, err
}

// SyscallFunc5 is a syscall handler that takes all of r1 through r5.
type SyscallFunc5 func(vm VM, r1, r2, r3, r4, r5 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 through r5.
func (fn SyscallFunc5) Invoke(vm VM, r1, r2, r3, r4, r5 uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, r3, r4, r5, cuIn)
	return r0, cuOut, err
}

View File

@ -35,15 +35,26 @@ type VMOpts struct {
Input []byte // mapped at VaddrInput Input []byte // mapped at VaddrInput
} }
// Syscall are callback handles from VM to Go. (work in progress) type Exception struct {
type Syscall interface { PC int64
Invoke(vm VM, r1, r2, r3, r4, r5 uint64, cuIn int64) (r0 uint64, cuOut int64, err error) Detail error
}
// Error implements the error interface.
// The message contains the PC of the exception followed by the detail error.
func (e *Exception) Error() string {
return fmt.Sprintf("exception at %d: %s", e.PC, e.Detail)
}
func (e *Exception) Unwrap() error {
return e.Detail
} }
// Exception codes. // Exception codes.
var ( var (
ExcDivideByZero = errors.New("division by zero") ExcDivideByZero = errors.New("division by zero")
ExcDivideOverflow = errors.New("divide overflow") ExcDivideOverflow = errors.New("divide overflow")
ExcOutOfCU = errors.New("compute unit overrun")
ExcCallDepth = errors.New("call depth exceeded")
ExcCallDest = errors.New("unknown symbol or syscall")
) )
type ExcBadAccess struct { type ExcBadAccess struct {
@ -65,41 +76,3 @@ func NewExcBadAccess(addr uint64, size uint32, write bool, reason string) ExcBad
func (e ExcBadAccess) Error() string { func (e ExcBadAccess) Error() string {
return fmt.Sprintf("bad memory access at %#x (size=%d write=%v), reason: %s", e.Addr, e.Size, e.Write, e.Reason) return fmt.Sprintf("bad memory access at %#x (size=%d write=%v), reason: %s", e.Addr, e.Size, e.Write, e.Reason)
} }
// Convenience adapters.
//
// Each SyscallFuncN type lets a plain Go function taking N register
// arguments be invoked with the full five-register Invoke signature;
// unused registers are discarded.

// SyscallFunc0 is a syscall handler that takes no register arguments.
type SyscallFunc0 func(vm VM, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn, ignoring r1 through r5.
func (fn SyscallFunc0) Invoke(vm VM, _, _, _, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, cuIn)
	return r0, cuOut, err
}

// SyscallFunc1 is a syscall handler that takes r1.
type SyscallFunc1 func(vm VM, r1 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1, ignoring r2 through r5.
func (fn SyscallFunc1) Invoke(vm VM, r1, _, _, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, cuIn)
	return r0, cuOut, err
}

// SyscallFunc2 is a syscall handler that takes r1 and r2.
type SyscallFunc2 func(vm VM, r1, r2 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 and r2, ignoring r3 through r5.
func (fn SyscallFunc2) Invoke(vm VM, r1, r2, _, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, cuIn)
	return r0, cuOut, err
}

// SyscallFunc3 is a syscall handler that takes r1 through r3.
type SyscallFunc3 func(vm VM, r1, r2, r3 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 through r3, ignoring r4 and r5.
func (fn SyscallFunc3) Invoke(vm VM, r1, r2, r3, _, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, r3, cuIn)
	return r0, cuOut, err
}

// SyscallFunc4 is a syscall handler that takes r1 through r4.
type SyscallFunc4 func(vm VM, r1, r2, r3, r4 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 through r4, ignoring r5.
func (fn SyscallFunc4) Invoke(vm VM, r1, r2, r3, r4, _ uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, r3, r4, cuIn)
	return r0, cuOut, err
}

// SyscallFunc5 is a syscall handler that takes all of r1 through r5.
type SyscallFunc5 func(vm VM, r1, r2, r3, r4, r5 uint64, cuIn int64) (r0 uint64, cuOut int64, err error)

// Invoke calls fn with r1 through r5.
func (fn SyscallFunc5) Invoke(vm VM, r1, r2, r3, r4, r5 uint64, cuIn int64) (r0 uint64, cuOut int64, err error) {
	r0, cuOut, err = fn(vm, r1, r2, r3, r4, r5, cuIn)
	return r0, cuOut, err
}