diff --git a/parser/cmd/main.go b/parser/cmd/main.go index 551658e0..29f660aa 100644 --- a/parser/cmd/main.go +++ b/parser/cmd/main.go @@ -34,6 +34,7 @@ func handleExit() { } func main() { + defer handleExit() rootCtx, rootCtxCancel := context.WithCancel(context.Background()) @@ -70,7 +71,7 @@ func main() { // // create a new publisher. publisher := pipeline.NewPublisher(logger, repository, vaaPushFunc) - watcher := watcher.NewWatcher(db.Database, config.MongoDatabase, publisher.Publish, logger) + watcher := watcher.NewWatcher(rootCtx, db.Database, config.MongoDatabase, publisher.Publish, logger) err = watcher.Start(rootCtx) if err != nil { logger.Fatal("failed to watch MongoDB", zap.Error(err)) diff --git a/parser/http/infrastructure/service.go b/parser/http/infrastructure/service.go index e1470a54..2b0a112f 100644 --- a/parser/http/infrastructure/service.go +++ b/parser/http/infrastructure/service.go @@ -69,7 +69,7 @@ func (s *Service) CheckAwsSQS(ctx context.Context) (bool, error) { return true, nil } // get queue attributes - queueAttributes, err := s.consumer.GetQueueAttributes() + queueAttributes, err := s.consumer.GetQueueAttributes(ctx) if err != nil || queueAttributes == nil { return false, err } diff --git a/parser/internal/sqs/sqs_consumer.go b/parser/internal/sqs/sqs_consumer.go index cca2ab06..e0c892bb 100644 --- a/parser/internal/sqs/sqs_consumer.go +++ b/parser/internal/sqs/sqs_consumer.go @@ -60,7 +60,7 @@ func WithWaitTimeSeconds(v int32) ConsumerOption { } // GetMessages retrieves messages from SQS. -func (c *Consumer) GetMessages() ([]aws_sqs_types.Message, error) { +func (c *Consumer) GetMessages(ctx context.Context) ([]aws_sqs_types.Message, error) { params := &aws_sqs.ReceiveMessageInput{ QueueUrl: aws.String(c.url), MaxNumberOfMessages: c.maxMessages, @@ -74,7 +74,7 @@ func (c *Consumer) GetMessages() ([]aws_sqs_types.Message, error) { VisibilityTimeout: c.visibilityTimeout, } - res, err := c.api.ReceiveMessage(context.TODO(), params) + res, err := c.api.ReceiveMessage(ctx, params) if err != nil { return nil, err } @@ -83,12 +83,12 @@ func (c *Consumer) GetMessages() ([]aws_sqs_types.Message, error) { } // DeleteMessage deletes messages from SQS. -func (c *Consumer) DeleteMessage(msg *aws_sqs_types.Message) error { +func (c *Consumer) DeleteMessage(ctx context.Context, msg *aws_sqs_types.Message) error { params := &aws_sqs.DeleteMessageInput{ QueueUrl: aws.String(c.url), ReceiptHandle: msg.ReceiptHandle, } - _, err := c.api.DeleteMessage(context.TODO(), params) + _, err := c.api.DeleteMessage(ctx, params) return err } @@ -99,12 +99,12 @@ func (c *Consumer) GetVisibilityTimeout() time.Duration { } // GetQueueAttributes get queue attributes. -func (c *Consumer) GetQueueAttributes() (*aws_sqs.GetQueueAttributesOutput, error) { +func (c *Consumer) GetQueueAttributes(ctx context.Context) (*aws_sqs.GetQueueAttributesOutput, error) { params := &aws_sqs.GetQueueAttributesInput{ QueueUrl: aws.String(c.url), AttributeNames: []aws_sqs_types.QueueAttributeName{ aws_sqs_types.QueueAttributeNameCreatedTimestamp, }, } - return c.api.GetQueueAttributes(context.TODO(), params) + return c.api.GetQueueAttributes(ctx, params) } diff --git a/parser/internal/sqs/sqs_producer.go b/parser/internal/sqs/sqs_producer.go index 0d9ed122..d40bd699 100644 --- a/parser/internal/sqs/sqs_producer.go +++ b/parser/internal/sqs/sqs_producer.go @@ -21,8 +21,8 @@ func NewProducer(awsConfig aws.Config, url string) (*Producer, error) { } // SendMessage sends messages to SQS. -func (p *Producer) SendMessage(groupID, deduplicationID, body string) error { - _, err := p.api.SendMessage(context.TODO(), +func (p *Producer) SendMessage(ctx context.Context, groupID, deduplicationID, body string) error { + _, err := p.api.SendMessage(ctx, &aws_sqs.SendMessageInput{ MessageGroupId: aws.String(groupID), MessageDeduplicationId: aws.String(deduplicationID), diff --git a/parser/pipeline/publisher.go b/parser/pipeline/publisher.go index bd2bed27..0253d770 100644 --- a/parser/pipeline/publisher.go +++ b/parser/pipeline/publisher.go @@ -23,7 +23,7 @@ func NewPublisher(logger *zap.Logger, repository *parser.Repository, pushFunc qu } // Publish sends a VaaEvent for the vaa that has parse configuration defined. -func (p *Publisher) Publish(e *watcher.Event) { +func (p *Publisher) Publish(ctx context.Context, e *watcher.Event) { // deserializes the binary representation of a VAA vaa, err := vaa.Unmarshal(e.Vaas) if err != nil { @@ -42,7 +42,7 @@ func (p *Publisher) Publish(e *watcher.Event) { } // push messages to queue. - err = p.pushFunc(context.TODO(), &event) + err = p.pushFunc(ctx, &event) if err != nil { p.logger.Error("can not push event to queue", zap.Error(err), zap.String("event", event.ID())) } diff --git a/parser/queue/vaa_sqs.go b/parser/queue/vaa_sqs.go index 31b9e8d4..911fecd7 100644 --- a/parser/queue/vaa_sqs.go +++ b/parser/queue/vaa_sqs.go @@ -44,14 +44,14 @@ func WithChannelSize(size int) SQSOption { } // Publish sends the message to a SQS queue. -func (q *SQS) Publish(_ context.Context, message *VaaEvent) error { +func (q *SQS) Publish(ctx context.Context, message *VaaEvent) error { body, err := json.Marshal(message) if err != nil { return err } groupID := fmt.Sprintf("%d/%s", message.ChainID, message.EmitterAddress) deduplicationID := fmt.Sprintf("%d/%s/%d", message.ChainID, message.EmitterAddress, message.Sequence) - return q.producer.SendMessage(groupID, deduplicationID, string(body)) + return q.producer.SendMessage(ctx, groupID, deduplicationID, string(body)) } // Consume returns the channel with the received messages from SQS queue. @@ -62,7 +62,7 @@ func (q *SQS) Consume(ctx context.Context) <-chan *ConsumerMessage { case <-ctx.Done(): return default: - messages, err := q.consumer.GetMessages() + messages, err := q.consumer.GetMessages(ctx) if err != nil { q.logger.Error("Error getting messages from SQS", zap.Error(err)) continue @@ -78,7 +78,7 @@ func (q *SQS) Consume(ctx context.Context) <-chan *ConsumerMessage { q.ch <- &ConsumerMessage{ Data: &body, Ack: func() { - if err := q.consumer.DeleteMessage(&msg); err != nil { + if err := q.consumer.DeleteMessage(ctx, &msg); err != nil { q.logger.Error("Error deleting message from SQS", zap.Error(err)) } }, diff --git a/parser/watcher/watcher.go b/parser/watcher/watcher.go index cdd98436..c9e45dba 100644 --- a/parser/watcher/watcher.go +++ b/parser/watcher/watcher.go @@ -18,7 +18,7 @@ type Watcher struct { } // WatcherFunc is a function to send database changes. -type WatcherFunc func(*Event) +type WatcherFunc func(context.Context, *Event) type watchEvent struct { DocumentKey documentKey `bson:"documentKey"` @@ -47,7 +47,7 @@ const queryTemplate = ` ` // NewWatcher creates a new database event watcher. -func NewWatcher(db *mongo.Database, dbName string, handler WatcherFunc, logger *zap.Logger) *Watcher { +func NewWatcher(ctx context.Context, db *mongo.Database, dbName string, handler WatcherFunc, logger *zap.Logger) *Watcher { return &Watcher{ db: db, dbName: dbName, @@ -76,7 +76,7 @@ func (w *Watcher) Start(ctx context.Context) error { w.logger.Error("Error unmarshalling event", zap.Error(err)) continue } - w.handler(&Event{ + w.handler(ctx, &Event{ ID: e.DbFullDocument.ID, Vaas: e.DbFullDocument.Vaas, })