Fix: fly memory leaks in sqs consumer (#126)
This commit is contained in:
parent
821071086d
commit
7fba3537a1
|
@ -83,10 +83,10 @@ func (c *Consumer) GetMessages() ([]*aws_sqs.Message, error) {
|
|||
}
|
||||
|
||||
// DeleteMessage deletes messages from SQS.
|
||||
func (c *Consumer) DeleteMessage(msg *aws_sqs.Message) error {
|
||||
func (c *Consumer) DeleteMessage(id *string) error {
|
||||
params := &aws_sqs.DeleteMessageInput{
|
||||
QueueUrl: aws.String(c.url),
|
||||
ReceiptHandle: msg.ReceiptHandle,
|
||||
ReceiptHandle: id,
|
||||
}
|
||||
_, err := c.api.DeleteMessage(params)
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
// VAAQueueConsumeFunc is a function to obtain messages from a queue
|
||||
type VAAQueueConsumeFunc func(context.Context) <-chan *queue.Message
|
||||
type VAAQueueConsumeFunc func(context.Context) <-chan queue.Message
|
||||
|
||||
// VAAQueueConsumer represents a VAA queue consumer.
|
||||
type VAAQueueConsumer struct {
|
||||
|
@ -38,41 +38,40 @@ func NewVAAQueueConsumer(
|
|||
// Start consumes messages from VAA queue and store those messages in a repository.
|
||||
func (c *VAAQueueConsumer) Start(ctx context.Context) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-c.consume(ctx):
|
||||
v, err := vaa.Unmarshal(msg.Data)
|
||||
if err != nil {
|
||||
c.logger.Error("Error unmarshalling vaa", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.IsExpired() {
|
||||
c.logger.Warn("Message with vaa expired", zap.String("id", v.MessageID()))
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.repository.UpsertVaa(ctx, v, msg.Data)
|
||||
if err != nil {
|
||||
c.logger.Error("Error inserting vaa in repository",
|
||||
zap.String("id", v.MessageID()),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.notifyFunc(ctx, v, msg.Data)
|
||||
if err != nil {
|
||||
c.logger.Error("Error notifying vaa",
|
||||
zap.String("id", v.MessageID()),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
msg.Ack()
|
||||
c.logger.Info("Vaa save in repository", zap.String("id", v.MessageID()))
|
||||
for msg := range c.consume(ctx) {
|
||||
v, err := vaa.Unmarshal(msg.Data())
|
||||
if err != nil {
|
||||
c.logger.Error("Error unmarshalling vaa", zap.Error(err))
|
||||
msg.Failed()
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.IsExpired() {
|
||||
c.logger.Warn("Message with vaa expired", zap.String("id", v.MessageID()))
|
||||
msg.Failed()
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.repository.UpsertVaa(ctx, v, msg.Data())
|
||||
if err != nil {
|
||||
c.logger.Error("Error inserting vaa in repository",
|
||||
zap.String("id", v.MessageID()),
|
||||
zap.Error(err))
|
||||
msg.Failed()
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.notifyFunc(ctx, v, msg.Data())
|
||||
if err != nil {
|
||||
c.logger.Error("Error notifying vaa",
|
||||
zap.String("id", v.MessageID()),
|
||||
zap.Error(err))
|
||||
msg.Failed()
|
||||
continue
|
||||
}
|
||||
|
||||
msg.Done()
|
||||
c.logger.Info("Vaa save in repository", zap.String("id", v.MessageID()))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package queue
|
||||
|
||||
// Message represents a message from a queue.
|
||||
type Message struct {
|
||||
Data []byte
|
||||
Ack func()
|
||||
IsExpired func() bool
|
||||
type Message interface {
|
||||
Data() []byte
|
||||
Done()
|
||||
Failed()
|
||||
IsExpired() bool
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ type VAAInMemoryOption func(*VAAInMemory)
|
|||
|
||||
// VAAInMemory represents VAA queue in memory.
|
||||
type VAAInMemory struct {
|
||||
ch chan *Message
|
||||
ch chan Message
|
||||
size int
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ func NewVAAInMemory(opts ...VAAInMemoryOption) *VAAInMemory {
|
|||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
m.ch = make(chan *Message, m.size)
|
||||
m.ch = make(chan Message, m.size)
|
||||
return m
|
||||
}
|
||||
|
||||
|
@ -34,15 +34,29 @@ func WithSize(v int) VAAInMemoryOption {
|
|||
|
||||
// Publish sends the message to a channel.
|
||||
func (i *VAAInMemory) Publish(_ context.Context, v *vaa.VAA, data []byte) error {
|
||||
i.ch <- &Message{
|
||||
Data: data,
|
||||
Ack: func() {},
|
||||
IsExpired: func() bool { return false },
|
||||
i.ch <- &memoryConsumerMessage{
|
||||
data: data,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Consume returns the channel with the received messages.
|
||||
func (i *VAAInMemory) Consume(_ context.Context) <-chan *Message {
|
||||
func (i *VAAInMemory) Consume(_ context.Context) <-chan Message {
|
||||
return i.ch
|
||||
}
|
||||
|
||||
type memoryConsumerMessage struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (m *memoryConsumerMessage) Data() []byte {
|
||||
return m.data
|
||||
}
|
||||
|
||||
func (m *memoryConsumerMessage) Done() {}
|
||||
|
||||
func (m *memoryConsumerMessage) Failed() {}
|
||||
|
||||
func (m *memoryConsumerMessage) IsExpired() bool {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/wormhole-foundation/wormhole-explorer/fly/internal/sqs"
|
||||
|
@ -19,8 +20,9 @@ type SQSOption func(*SQS)
|
|||
type SQS struct {
|
||||
producer *sqs.Producer
|
||||
consumer *sqs.Consumer
|
||||
ch chan *Message
|
||||
ch chan Message
|
||||
chSize int
|
||||
wg sync.WaitGroup
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
|
@ -34,7 +36,7 @@ func NewVAASQS(producer *sqs.Producer, consumer *sqs.Consumer, logger *zap.Logge
|
|||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
s.ch = make(chan *Message, s.chSize)
|
||||
s.ch = make(chan Message, s.chSize)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -53,39 +55,35 @@ func (q *SQS) Publish(_ context.Context, v *vaa.VAA, data []byte) error {
|
|||
}
|
||||
|
||||
// Consume returns the channel with the received messages from SQS queue.
|
||||
func (q *SQS) Consume(ctx context.Context) <-chan *Message {
|
||||
func (q *SQS) Consume(ctx context.Context) <-chan Message {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
messages, err := q.consumer.GetMessages()
|
||||
messages, err := q.consumer.GetMessages()
|
||||
if err != nil {
|
||||
q.logger.Error("Error getting messages from SQS", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
expiredAt := time.Now().Add(q.consumer.GetVisibilityTimeout())
|
||||
for _, msg := range messages {
|
||||
body, err := base64.StdEncoding.DecodeString(*msg.Body)
|
||||
if err != nil {
|
||||
q.logger.Error("Error getting messages from SQS", zap.Error(err))
|
||||
q.logger.Error("Error decoding message from SQS", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
expiredAt := time.Now().Add(q.consumer.GetVisibilityTimeout())
|
||||
for _, msg := range messages {
|
||||
body, err := base64.StdEncoding.DecodeString(*msg.Body)
|
||||
if err != nil {
|
||||
q.logger.Error("Error decoding message from SQS", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
//TODO check if callback is better than channel
|
||||
q.ch <- &Message{
|
||||
Data: body,
|
||||
Ack: func() {
|
||||
if err := q.consumer.DeleteMessage(msg); err != nil {
|
||||
q.logger.Error("Error deleting message from SQS", zap.Error(err))
|
||||
}
|
||||
},
|
||||
IsExpired: func() bool {
|
||||
return expiredAt.Before(time.Now())
|
||||
},
|
||||
}
|
||||
|
||||
//TODO check if callback is better than channel
|
||||
q.wg.Add(1)
|
||||
q.ch <- &sqsConsumerMessage{
|
||||
id: msg.ReceiptHandle,
|
||||
data: body,
|
||||
wg: &q.wg,
|
||||
logger: q.logger,
|
||||
consumer: q.consumer,
|
||||
expiredAt: expiredAt,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
q.wg.Wait()
|
||||
}
|
||||
}()
|
||||
return q.ch
|
||||
|
@ -95,3 +93,32 @@ func (q *SQS) Consume(ctx context.Context) <-chan *Message {
|
|||
func (q *SQS) Close() {
|
||||
close(q.ch)
|
||||
}
|
||||
|
||||
type sqsConsumerMessage struct {
|
||||
data []byte
|
||||
consumer *sqs.Consumer
|
||||
id *string
|
||||
logger *zap.Logger
|
||||
expiredAt time.Time
|
||||
wg *sync.WaitGroup
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *sqsConsumerMessage) Data() []byte {
|
||||
return m.data
|
||||
}
|
||||
|
||||
func (m *sqsConsumerMessage) Done() {
|
||||
if err := m.consumer.DeleteMessage(m.id); err != nil {
|
||||
m.logger.Error("Error deleting message from SQS", zap.Error(err))
|
||||
}
|
||||
m.wg.Done()
|
||||
}
|
||||
|
||||
func (m *sqsConsumerMessage) Failed() {
|
||||
m.wg.Done()
|
||||
}
|
||||
|
||||
func (m *sqsConsumerMessage) IsExpired() bool {
|
||||
return m.expiredAt.Before(time.Now())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue