sealevel: add memcmp syscall & tests
This commit is contained in:
parent
1620df2dc9
commit
4dbcc11d68
Binary file not shown.
Binary file not shown.
|
@ -273,6 +273,87 @@ func TestInterpreter_Memcpy_Overlapping(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// The TestInterpreter_Memcmp_Matches function tests that the memcmp
|
||||
// syscall works as expected by comparing two instances of "abcdabcd1234"
|
||||
// The expected result is that the two strings match and the program
|
||||
// writes "Memory chunks matched." to the program log.
|
||||
func TestInterpreter_Memcmp_Matches(t *testing.T) {
|
||||
loader, err := loader.NewLoaderFromBytes(fixtures.Load(t, "sbpf", "memcmp_matched.so"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loader)
|
||||
|
||||
program, err := loader.Load()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, program)
|
||||
|
||||
require.NoError(t, program.Verify())
|
||||
|
||||
syscalls := sbpf.NewSyscallRegistry()
|
||||
syscalls.Register("sol_log_", SyscallLog)
|
||||
syscalls.Register("log_64", SyscallLog64)
|
||||
syscalls.Register("my_memcmp", SyscallMemcmp)
|
||||
|
||||
var log LogRecorder
|
||||
|
||||
interpreter := sbpf.NewInterpreter(program, &sbpf.VMOpts{
|
||||
HeapSize: 32 * 1024,
|
||||
Input: nil,
|
||||
MaxCU: 10000,
|
||||
Syscalls: syscalls,
|
||||
Context: &Execution{Log: &log},
|
||||
})
|
||||
require.NotNil(t, interpreter)
|
||||
|
||||
err = interpreter.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, log.Logs, []string{
|
||||
"Program log: Memory chunks matched.",
|
||||
})
|
||||
}
|
||||
|
||||
// The TestInterpreter_Memcmp_Does_Not_Match function tests that the memcmp
|
||||
// syscall works as expected by comparing the string literals "abcdabcd1234"
|
||||
// and "BLAHabcd1234"
|
||||
// The expected result is that the two strings do not match and the difference
|
||||
// between the first non-matching characters (0x61 - 0x42 = 0x1f) is returned,
|
||||
// and the program checks these and returns messages accordingly.
|
||||
func TestInterpreter_Memcmp_Does_Not_Match(t *testing.T) {
|
||||
loader, err := loader.NewLoaderFromBytes(fixtures.Load(t, "sbpf", "memcmp_not_matched.so"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loader)
|
||||
|
||||
program, err := loader.Load()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, program)
|
||||
|
||||
require.NoError(t, program.Verify())
|
||||
|
||||
syscalls := sbpf.NewSyscallRegistry()
|
||||
syscalls.Register("sol_log_", SyscallLog)
|
||||
syscalls.Register("log_64", SyscallLog64)
|
||||
syscalls.Register("my_memcmp", SyscallMemcmp)
|
||||
|
||||
var log LogRecorder
|
||||
|
||||
interpreter := sbpf.NewInterpreter(program, &sbpf.VMOpts{
|
||||
HeapSize: 32 * 1024,
|
||||
Input: nil,
|
||||
MaxCU: 10000,
|
||||
Syscalls: syscalls,
|
||||
Context: &Execution{Log: &log},
|
||||
})
|
||||
require.NotNil(t, interpreter)
|
||||
|
||||
err = interpreter.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, log.Logs, []string{
|
||||
"Program log: Memory chunks did not match.",
|
||||
"Program log: Difference between non-matching character was correctly returned.",
|
||||
})
|
||||
}
|
||||
|
||||
type executeCase struct {
|
||||
Name string
|
||||
Program string
|
||||
|
|
|
@ -16,6 +16,7 @@ func Syscalls() sbpf.SyscallRegistry {
|
|||
reg.Register("sol_log_pubkey", SyscallLogPubkey)
|
||||
reg.Register("sol_memcpy_", SyscallMemcpy)
|
||||
reg.Register("sol_memmove_", SyscallMemmove)
|
||||
reg.Register("sol_memcmp_", SyscallMemcmp)
|
||||
return reg
|
||||
}
|
||||
|
||||
|
|
|
@ -65,3 +65,34 @@ func SyscallMemmoveImpl(vm sbpf.VM, dst, src, n uint64, cuIn int) (r0 uint64, cu
|
|||
}
|
||||
|
||||
var SyscallMemmove = sbpf.SyscallFunc3(SyscallMemmoveImpl)
|
||||
|
||||
// SyscallMemcmpImpl is the implementation for the memcmp (sol_memcmp_) syscall.
|
||||
func SyscallMemcmpImpl(vm sbpf.VM, addr1, addr2, n, resultAddr uint64, cuIn int) (r0 uint64, cuOut int, err error) {
|
||||
cuOut = MemOpConsume(cuIn, n)
|
||||
if cuOut < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
slice1, err := vm.Translate(addr1, uint32(n), false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
slice2, err := vm.Translate(addr2, uint32(n), false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cmpResult := int32(0)
|
||||
for count := uint64(0); count < n; count++ {
|
||||
b1 := slice1[count]
|
||||
b2 := slice2[count]
|
||||
if b1 != b2 {
|
||||
cmpResult = int32(b1 - b2)
|
||||
break
|
||||
}
|
||||
}
|
||||
err = vm.Write32(resultAddr, uint32(cmpResult))
|
||||
return
|
||||
}
|
||||
|
||||
var SyscallMemcmp = sbpf.SyscallFunc4(SyscallMemcmpImpl)
|
||||
|
|
Loading…
Reference in New Issue