Fix segfault from empty GetSignedVAARequest (#1069)
This commit is contained in:
parent
bad4f7061b
commit
b18a6c8c2f
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/certusone/wormhole/node/pkg/common"
|
||||
"github.com/certusone/wormhole/node/pkg/db"
|
||||
publicrpcv1 "github.com/certusone/wormhole/node/pkg/proto/publicrpc/v1"
|
||||
|
@ -59,6 +60,10 @@ func (s *PublicrpcServer) GetLastHeartbeats(ctx context.Context, req *publicrpcv
|
|||
}
|
||||
|
||||
func (s *PublicrpcServer) GetSignedVAA(ctx context.Context, req *publicrpcv1.GetSignedVAARequest) (*publicrpcv1.GetSignedVAAResponse, error) {
|
||||
if req.MessageId == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "no message ID specified")
|
||||
}
|
||||
|
||||
address, err := hex.DecodeString(req.MessageId.EmitterAddress)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("failed to decode address: %v", err))
|
||||
|
@ -66,9 +71,6 @@ func (s *PublicrpcServer) GetSignedVAA(ctx context.Context, req *publicrpcv1.Get
|
|||
if len(address) != 32 {
|
||||
return nil, status.Error(codes.InvalidArgument, "address must be 32 bytes")
|
||||
}
|
||||
if req.MessageId == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "no message ID specified")
|
||||
}
|
||||
|
||||
addr := vaa.Address{}
|
||||
copy(addr[:], address)
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package publicrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
publicrpcv1 "github.com/certusone/wormhole/node/pkg/proto/publicrpc/v1"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetSignedVAANoMessage(t *testing.T) {
|
||||
msg := publicrpcv1.GetSignedVAARequest{}
|
||||
ctx := context.Background()
|
||||
|
||||
logger, _ := zap.NewProduction()
|
||||
server := &PublicrpcServer{logger: logger}
|
||||
|
||||
resp, err := server.GetSignedVAA(ctx, &msg)
|
||||
assert.Nil(t, resp)
|
||||
|
||||
expected_err := status.Error(codes.InvalidArgument, "no message ID specified")
|
||||
assert.Equal(t, expected_err, err)
|
||||
}
|
||||
|
||||
func TestGetSignedVAANoAddress(t *testing.T) {
|
||||
msg := publicrpcv1.GetSignedVAARequest{MessageId: &publicrpcv1.MessageID{}}
|
||||
ctx := context.Background()
|
||||
|
||||
logger, _ := zap.NewProduction()
|
||||
server := &PublicrpcServer{logger: logger}
|
||||
|
||||
resp, err := server.GetSignedVAA(ctx, &msg)
|
||||
assert.Nil(t, resp)
|
||||
|
||||
expected_err := status.Error(codes.InvalidArgument, "address must be 32 bytes")
|
||||
assert.Equal(t, expected_err, err)
|
||||
}
|
||||
|
||||
func TestGetSignedVAABadAddress(t *testing.T) {
|
||||
chainID := uint32(1)
|
||||
emitterAddr := "AAAA"
|
||||
seq := uint64(1)
|
||||
|
||||
msg := publicrpcv1.GetSignedVAARequest{
|
||||
MessageId: &publicrpcv1.MessageID{
|
||||
EmitterChain: publicrpcv1.ChainID(chainID),
|
||||
EmitterAddress: emitterAddr,
|
||||
Sequence: seq,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
logger, _ := zap.NewProduction()
|
||||
server := &PublicrpcServer{logger: logger}
|
||||
|
||||
resp, err := server.GetSignedVAA(ctx, &msg)
|
||||
assert.Nil(t, resp)
|
||||
|
||||
expected_err := status.Error(codes.InvalidArgument, "address must be 32 bytes")
|
||||
assert.Equal(t, expected_err, err)
|
||||
}
|
Loading…
Reference in New Issue