Fix: fly memory leaks in sqs consumer (#126)

This commit is contained in:
ftocal 2023-02-02 14:51:33 -03:00 committed by GitHub
parent 821071086d
commit 7fba3537a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 117 additions and 76 deletions

View File

@ -83,10 +83,10 @@ func (c *Consumer) GetMessages() ([]*aws_sqs.Message, error) {
}
// DeleteMessage deletes messages from SQS.
func (c *Consumer) DeleteMessage(msg *aws_sqs.Message) error {
func (c *Consumer) DeleteMessage(id *string) error {
params := &aws_sqs.DeleteMessageInput{
QueueUrl: aws.String(c.url),
ReceiptHandle: msg.ReceiptHandle,
ReceiptHandle: id,
}
_, err := c.api.DeleteMessage(params)

View File

@ -11,7 +11,7 @@ import (
)
// VAAQueueConsumeFunc is a function to obtain messages from a queue
type VAAQueueConsumeFunc func(context.Context) <-chan *queue.Message
type VAAQueueConsumeFunc func(context.Context) <-chan queue.Message
// VAAQueueConsumer represents a VAA queue consumer.
type VAAQueueConsumer struct {
@ -38,41 +38,40 @@ func NewVAAQueueConsumer(
// Start consumes messages from VAA queue and store those messages in a repository.
func (c *VAAQueueConsumer) Start(ctx context.Context) {
go func() {
for {
select {
case <-ctx.Done():
return
case msg := <-c.consume(ctx):
v, err := vaa.Unmarshal(msg.Data)
if err != nil {
c.logger.Error("Error unmarshalling vaa", zap.Error(err))
continue
}
if msg.IsExpired() {
c.logger.Warn("Message with vaa expired", zap.String("id", v.MessageID()))
continue
}
err = c.repository.UpsertVaa(ctx, v, msg.Data)
if err != nil {
c.logger.Error("Error inserting vaa in repository",
zap.String("id", v.MessageID()),
zap.Error(err))
continue
}
err = c.notifyFunc(ctx, v, msg.Data)
if err != nil {
c.logger.Error("Error notifying vaa",
zap.String("id", v.MessageID()),
zap.Error(err))
continue
}
msg.Ack()
c.logger.Info("Vaa save in repository", zap.String("id", v.MessageID()))
for msg := range c.consume(ctx) {
v, err := vaa.Unmarshal(msg.Data())
if err != nil {
c.logger.Error("Error unmarshalling vaa", zap.Error(err))
msg.Failed()
continue
}
if msg.IsExpired() {
c.logger.Warn("Message with vaa expired", zap.String("id", v.MessageID()))
msg.Failed()
continue
}
err = c.repository.UpsertVaa(ctx, v, msg.Data())
if err != nil {
c.logger.Error("Error inserting vaa in repository",
zap.String("id", v.MessageID()),
zap.Error(err))
msg.Failed()
continue
}
err = c.notifyFunc(ctx, v, msg.Data())
if err != nil {
c.logger.Error("Error notifying vaa",
zap.String("id", v.MessageID()),
zap.Error(err))
msg.Failed()
continue
}
msg.Done()
c.logger.Info("Vaa save in repository", zap.String("id", v.MessageID()))
}
}()
}

View File

@ -1,8 +1,9 @@
package queue
// Message represents a message from a queue.
type Message struct {
Data []byte
Ack func()
IsExpired func() bool
type Message interface {
Data() []byte
Done()
Failed()
IsExpired() bool
}

View File

@ -11,7 +11,7 @@ type VAAInMemoryOption func(*VAAInMemory)
// VAAInMemory represents VAA queue in memory.
type VAAInMemory struct {
ch chan *Message
ch chan Message
size int
}
@ -21,7 +21,7 @@ func NewVAAInMemory(opts ...VAAInMemoryOption) *VAAInMemory {
for _, opt := range opts {
opt(m)
}
m.ch = make(chan *Message, m.size)
m.ch = make(chan Message, m.size)
return m
}
@ -34,15 +34,29 @@ func WithSize(v int) VAAInMemoryOption {
// Publish sends the message to a channel.
func (i *VAAInMemory) Publish(_ context.Context, v *vaa.VAA, data []byte) error {
i.ch <- &Message{
Data: data,
Ack: func() {},
IsExpired: func() bool { return false },
i.ch <- &memoryConsumerMessage{
data: data,
}
return nil
}
// Consume returns the channel with the received messages.
func (i *VAAInMemory) Consume(_ context.Context) <-chan *Message {
func (i *VAAInMemory) Consume(_ context.Context) <-chan Message {
return i.ch
}
type memoryConsumerMessage struct {
data []byte
}
func (m *memoryConsumerMessage) Data() []byte {
return m.data
}
func (m *memoryConsumerMessage) Done() {}
func (m *memoryConsumerMessage) Failed() {}
func (m *memoryConsumerMessage) IsExpired() bool {
return false
}

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"fmt"
"sync"
"time"
"github.com/wormhole-foundation/wormhole-explorer/fly/internal/sqs"
@ -19,8 +20,9 @@ type SQSOption func(*SQS)
type SQS struct {
producer *sqs.Producer
consumer *sqs.Consumer
ch chan *Message
ch chan Message
chSize int
wg sync.WaitGroup
logger *zap.Logger
}
@ -34,7 +36,7 @@ func NewVAASQS(producer *sqs.Producer, consumer *sqs.Consumer, logger *zap.Logge
for _, opt := range opts {
opt(s)
}
s.ch = make(chan *Message, s.chSize)
s.ch = make(chan Message, s.chSize)
return s
}
@ -53,39 +55,35 @@ func (q *SQS) Publish(_ context.Context, v *vaa.VAA, data []byte) error {
}
// Consume returns the channel with the received messages from SQS queue.
func (q *SQS) Consume(ctx context.Context) <-chan *Message {
func (q *SQS) Consume(ctx context.Context) <-chan Message {
go func() {
for {
select {
case <-ctx.Done():
return
default:
messages, err := q.consumer.GetMessages()
messages, err := q.consumer.GetMessages()
if err != nil {
q.logger.Error("Error getting messages from SQS", zap.Error(err))
continue
}
expiredAt := time.Now().Add(q.consumer.GetVisibilityTimeout())
for _, msg := range messages {
body, err := base64.StdEncoding.DecodeString(*msg.Body)
if err != nil {
q.logger.Error("Error getting messages from SQS", zap.Error(err))
q.logger.Error("Error decoding message from SQS", zap.Error(err))
continue
}
expiredAt := time.Now().Add(q.consumer.GetVisibilityTimeout())
for _, msg := range messages {
body, err := base64.StdEncoding.DecodeString(*msg.Body)
if err != nil {
q.logger.Error("Error decoding message from SQS", zap.Error(err))
continue
}
//TODO check if callback is better than channel
q.ch <- &Message{
Data: body,
Ack: func() {
if err := q.consumer.DeleteMessage(msg); err != nil {
q.logger.Error("Error deleting message from SQS", zap.Error(err))
}
},
IsExpired: func() bool {
return expiredAt.Before(time.Now())
},
}
//TODO check if callback is better than channel
q.wg.Add(1)
q.ch <- &sqsConsumerMessage{
id: msg.ReceiptHandle,
data: body,
wg: &q.wg,
logger: q.logger,
consumer: q.consumer,
expiredAt: expiredAt,
ctx: ctx,
}
}
q.wg.Wait()
}
}()
return q.ch
@ -95,3 +93,32 @@ func (q *SQS) Consume(ctx context.Context) <-chan *Message {
func (q *SQS) Close() {
close(q.ch)
}
type sqsConsumerMessage struct {
data []byte
consumer *sqs.Consumer
id *string
logger *zap.Logger
expiredAt time.Time
wg *sync.WaitGroup
ctx context.Context
}
func (m *sqsConsumerMessage) Data() []byte {
return m.data
}
func (m *sqsConsumerMessage) Done() {
if err := m.consumer.DeleteMessage(m.id); err != nil {
m.logger.Error("Error deleting message from SQS", zap.Error(err))
}
m.wg.Done()
}
func (m *sqsConsumerMessage) Failed() {
m.wg.Done()
}
func (m *sqsConsumerMessage) IsExpired() bool {
return m.expiredAt.Before(time.Now())
}