Use process-concurrency for Dirk accountmanager.

The Dirk accountmanager was using a local scatter/gather concurrency
method to obtain wallets, however this uses the parallelism of the Vouch
server rather than the Dirk server.  This chnages the Dirk
accountmanager to use a configuration value to select the concurrency
level.

This also standardizes the use of process concurency to allow for
hierarchical definition of the value.
This commit is contained in:
Jim McDonald 2021-07-22 22:35:09 +01:00
parent 700d1a19d9
commit c10f060848
No known key found for this signature in database
GPG Key ID: 89CEB61B2AD2A5E7
17 changed files with 223 additions and 311 deletions

View File

@ -1,5 +1,7 @@
1.1.0:
- added metrics to track strategy operation results
- fetch wallet accounts from Dirk in parallel
- fetch process-concurrency configuration value from most specific point in hierarchy
- add metrics to track strategy operation results
- provide release metric in `vouch_release`
- provide ready metric in `vouch_ready`
- handle chain reorganisations, updating duties as appropriate

14
main.go
View File

@ -62,6 +62,7 @@ import (
firstattestationdatastrategy "github.com/attestantio/vouch/strategies/attestationdata/first"
bestbeaconblockproposalstrategy "github.com/attestantio/vouch/strategies/beaconblockproposal/best"
firstbeaconblockproposalstrategy "github.com/attestantio/vouch/strategies/beaconblockproposal/first"
"github.com/attestantio/vouch/util"
"github.com/aws/aws-sdk-go/aws/credentials"
homedir "github.com/mitchellh/go-homedir"
"github.com/opentracing/opentracing-go"
@ -342,7 +343,7 @@ func startServices(ctx context.Context, majordomo majordomo.Service) error {
log.Trace().Msg("Starting attester")
attester, err := standardattester.New(ctx,
standardattester.WithLogLevel(logLevel(viper.GetString("attester.log-level"))),
standardattester.WithProcessConcurrency(viper.GetInt64("process-concurrency")),
standardattester.WithProcessConcurrency(util.ProcessConcurrency("attester")),
standardattester.WithSlotsPerEpochProvider(eth2Client.(eth2client.SlotsPerEpochProvider)),
standardattester.WithAttestationDataProvider(attestationDataProvider),
standardattester.WithAttestationsSubmitter(submitterStrategy.(submitter.AttestationsSubmitter)),
@ -380,7 +381,7 @@ func startServices(ctx context.Context, majordomo majordomo.Service) error {
log.Trace().Msg("Starting beacon committee subscriber service")
beaconCommitteeSubscriber, err := standardbeaconcommitteesubscriber.New(ctx,
standardbeaconcommitteesubscriber.WithLogLevel(logLevel(viper.GetString("beaconcommiteesubscriber.log-level"))),
standardbeaconcommitteesubscriber.WithProcessConcurrency(viper.GetInt64("process-concurrency")),
standardbeaconcommitteesubscriber.WithProcessConcurrency(util.ProcessConcurrency("beaconcommitteesubscriber")),
standardbeaconcommitteesubscriber.WithMonitor(monitor.(metrics.BeaconCommitteeSubscriptionMonitor)),
standardbeaconcommitteesubscriber.WithAttesterDutiesProvider(eth2Client.(eth2client.AttesterDutiesProvider)),
standardbeaconcommitteesubscriber.WithAttestationAggregator(attestationAggregator),
@ -612,6 +613,7 @@ func startAccountManager(ctx context.Context, monitor metrics.Service, eth2Clien
dirkaccountmanager.WithLogLevel(logLevel(viper.GetString("accountmanager.dirk.log-level"))),
dirkaccountmanager.WithMonitor(monitor.(metrics.AccountManagerMonitor)),
dirkaccountmanager.WithClientMonitor(monitor.(metrics.ClientMonitor)),
dirkaccountmanager.WithProcessConcurrency(util.ProcessConcurrency("accountmanager.dirk")),
dirkaccountmanager.WithValidatorsManager(validatorsManager),
dirkaccountmanager.WithEndpoints(viper.GetStringSlice("accountmanager.dirk.endpoints")),
dirkaccountmanager.WithAccountPaths(viper.GetStringSlice("accountmanager.dirk.accounts")),
@ -683,7 +685,7 @@ func selectAttestationDataProvider(ctx context.Context,
}
attestationDataProvider, err = bestattestationdatastrategy.New(ctx,
bestattestationdatastrategy.WithClientMonitor(monitor.(metrics.ClientMonitor)),
bestattestationdatastrategy.WithProcessConcurrency(viper.GetInt64("process-concurrency")),
bestattestationdatastrategy.WithProcessConcurrency(util.ProcessConcurrency("strategies.attestationdata.best")),
bestattestationdatastrategy.WithLogLevel(logLevel(viper.GetString("strategies.attestationdata.log-level"))),
bestattestationdatastrategy.WithAttestationDataProviders(attestationDataProviders),
)
@ -744,7 +746,7 @@ func selectAggregateAttestationProvider(ctx context.Context,
}
aggregateAttestationProvider, err = bestaggregateattestationstrategy.New(ctx,
bestaggregateattestationstrategy.WithClientMonitor(monitor.(metrics.ClientMonitor)),
bestaggregateattestationstrategy.WithProcessConcurrency(viper.GetInt64("process-concurrency")),
bestaggregateattestationstrategy.WithProcessConcurrency(util.ProcessConcurrency("strategies.aggregateattestation.best")),
bestaggregateattestationstrategy.WithLogLevel(logLevel(viper.GetString("strategies.aggregateattestation.log-level"))),
bestaggregateattestationstrategy.WithAggregateAttestationProviders(aggregateAttestationProviders),
)
@ -804,7 +806,7 @@ func selectBeaconBlockProposalProvider(ctx context.Context,
}
beaconBlockProposalProvider, err = bestbeaconblockproposalstrategy.New(ctx,
bestbeaconblockproposalstrategy.WithClientMonitor(monitor.(metrics.ClientMonitor)),
bestbeaconblockproposalstrategy.WithProcessConcurrency(viper.GetInt64("process-concurrency")),
bestbeaconblockproposalstrategy.WithProcessConcurrency(util.ProcessConcurrency("strategies.beaconblockproposal.best")),
bestbeaconblockproposalstrategy.WithLogLevel(logLevel(viper.GetString("strategies.beaconblockproposal.log-level"))),
bestbeaconblockproposalstrategy.WithBeaconBlockProposalProviders(beaconBlockProposalProviders),
bestbeaconblockproposalstrategy.WithSignedBeaconBlockProvider(eth2Client.(eth2client.SignedBeaconBlockProvider)),
@ -860,7 +862,7 @@ func selectSubmitterStrategy(ctx context.Context, monitor metrics.Service, eth2C
}
submitter, err = multinodesubmitter.New(ctx,
multinodesubmitter.WithClientMonitor(monitor.(metrics.ClientMonitor)),
multinodesubmitter.WithProcessConcurrency(viper.GetInt64("process-concurrency")),
multinodesubmitter.WithProcessConcurrency(util.ProcessConcurrency("submitter.multinode")),
multinodesubmitter.WithLogLevel(logLevel(viper.GetString("submitter.log-level"))),
multinodesubmitter.WithBeaconBlockSubmitters(beaconBlockSubmitters),
multinodesubmitter.WithAttestationsSubmitters(attestationsSubmitters),

View File

@ -1,4 +1,4 @@
// Copyright © 2020 Attestant Limited.
// Copyright © 2020, 2021 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -29,6 +29,7 @@ type parameters struct {
logLevel zerolog.Level
monitor metrics.AccountManagerMonitor
clientMonitor metrics.ClientMonitor
processConcurrency int64
endpoints []string
accountPaths []string
clientCert []byte
@ -72,6 +73,13 @@ func WithClientMonitor(clientMonitor metrics.ClientMonitor) Parameter {
})
}
// WithProcessConcurrency sets the concurrency for the service.
func WithProcessConcurrency(concurrency int64) Parameter {
return parameterFunc(func(p *parameters) {
p.processConcurrency = concurrency
})
}
// WithEndpoints sets the endpoints to communicate with dirk.
func WithEndpoints(endpoints []string) Parameter {
return parameterFunc(func(p *parameters) {
@ -154,6 +162,9 @@ func parseAndCheckParameters(params ...Parameter) (*parameters, error) {
if parameters.clientMonitor == nil {
return nil, errors.New("no client monitor specified")
}
if parameters.processConcurrency <= 0 {
return nil, errors.New("process concurrency must be > 0")
}
if len(parameters.endpoints) == 0 {
return nil, errors.New("no endpoints specified")
}

View File

@ -22,6 +22,7 @@ import (
"strconv"
"strings"
"sync"
"time"
eth2client "github.com/attestantio/go-eth2-client"
api "github.com/attestantio/go-eth2-client/api/v1"
@ -29,13 +30,13 @@ import (
"github.com/attestantio/vouch/services/chaintime"
"github.com/attestantio/vouch/services/metrics"
"github.com/attestantio/vouch/services/validatorsmanager"
"github.com/attestantio/vouch/util"
"github.com/pkg/errors"
"github.com/rs/zerolog"
zerologger "github.com/rs/zerolog/log"
"github.com/wealdtech/go-bytesutil"
dirk "github.com/wealdtech/go-eth2-wallet-dirk"
e2wtypes "github.com/wealdtech/go-eth2-wallet-types/v2"
"golang.org/x/sync/semaphore"
"google.golang.org/grpc/credentials"
)
@ -44,6 +45,7 @@ type Service struct {
mutex sync.RWMutex
monitor metrics.AccountManagerMonitor
clientMonitor metrics.ClientMonitor
processConcurrency int64
endpoints []*dirk.Endpoint
accountPaths []string
credentials credentials.TransportCredentials
@ -107,6 +109,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
s := &Service{
monitor: parameters.monitor,
clientMonitor: parameters.clientMonitor,
processConcurrency: parameters.processConcurrency,
endpoints: endpoints,
accountPaths: parameters.accountPaths,
credentials: credentials,
@ -116,6 +119,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
currentEpochProvider: parameters.currentEpochProvider,
wallets: make(map[string]e2wtypes.Wallet),
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
if err := s.refreshAccounts(ctx); err != nil {
return nil, errors.Wrap(err, "failed to fetch initial accounts")
@ -162,23 +166,33 @@ func (s *Service) refreshAccounts(ctx context.Context) error {
}
verificationRegexes := accountPathsToVerificationRegexes(s.accountPaths)
// Fetch accounts for each wallet.
// Fetch accounts for each wallet in parallel.
started := time.Now()
accounts := make(map[phase0.BLSPubKey]e2wtypes.Account)
_, err := util.Scatter(len(wallets), func(offset int, entries int, mu *sync.RWMutex) (interface{}, error) {
for i := offset; i < offset+entries; i++ {
var accountsMu sync.Mutex
sem := semaphore.NewWeighted(s.processConcurrency)
var wg sync.WaitGroup
for i := range wallets {
wg.Add(1)
go func(ctx context.Context, sem *semaphore.Weighted, wg *sync.WaitGroup, i int, mu *sync.Mutex) {
defer wg.Done()
if err := sem.Acquire(ctx, 1); err != nil {
log.Error().Err(err).Msg("Failed to acquire semaphore")
return
}
defer sem.Release(1)
log := log.With().Str("wallet", wallets[i].Name()).Logger()
log.Trace().Dur("elapsed", time.Since(started)).Msg("Obtained semaphore")
walletAccounts := s.fetchAccountsForWallet(ctx, wallets[i], verificationRegexes)
mu.Lock()
log.Trace().Dur("elapsed", time.Since(started)).Int("accounts", len(walletAccounts)).Msg("Obtained accounts")
accountsMu.Lock()
for k, v := range walletAccounts {
accounts[k] = v
}
mu.Unlock()
}
return nil, nil
})
if err != nil {
log.Error().Err(err).Str("result", "failed").Msg("Failed to obtain accounts")
accountsMu.Unlock()
log.Trace().Dur("elapsed", time.Since(started)).Int("accounts", len(walletAccounts)).Msg("Imported accounts")
}(ctx, sem, &wg, i, &accountsMu)
}
log.Trace().Int("accounts", len(accounts)).Msg("Obtained accounts")
if len(accounts) == 0 && len(s.accounts) != 0 {

View File

@ -1,4 +1,4 @@
// Copyright © 2020 Attestant Limited.
// Copyright © 2020, 2021 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -190,6 +190,7 @@ func setupService(ctx context.Context, t *testing.T, endpoints []string, account
WithLogLevel(zerolog.TraceLevel),
WithMonitor(nullmetrics.New(context.Background())),
WithClientMonitor(nullmetrics.New(context.Background())),
WithProcessConcurrency(1),
WithEndpoints(endpoints),
WithAccountPaths(accountPaths),
WithClientCert([]byte(resources.ClientTest01Crt)),

View File

@ -1,4 +1,4 @@
// Copyright © 2020 Attestant Limited.
// Copyright © 2020, 2021 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@ -78,6 +78,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nil),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -90,12 +91,32 @@ func TestService(t *testing.T) {
},
err: "problem with parameters: no client monitor specified",
},
{
name: "ProcessConcurrencyZero",
params: []dirk.Parameter{
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(0),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
dirk.WithClientKey([]byte(resources.ClientTest01Key)),
dirk.WithCACert([]byte(resources.CACrt)),
dirk.WithValidatorsManager(validatorsManager),
dirk.WithDomainProvider(domainProvider),
dirk.WithFarFutureEpochProvider(farFutureEpochProvider),
dirk.WithCurrentEpochProvider(chainTime),
},
err: "problem with parameters: process concurrency must be > 0",
},
{
name: "EndpointsNil",
params: []dirk.Parameter{
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
dirk.WithClientKey([]byte(resources.ClientTest01Key)),
@ -113,6 +134,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -131,6 +153,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{""}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -150,6 +173,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"host:bad"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -169,6 +193,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"host:0"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -188,6 +213,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
dirk.WithClientKey([]byte(resources.ClientTest01Key)),
@ -205,6 +231,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -223,6 +250,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientKey([]byte(resources.ClientTest01Key)),
@ -240,6 +268,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -257,6 +286,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.Disabled),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -275,6 +305,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -292,6 +323,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -309,6 +341,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.TraceLevel),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -326,6 +359,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.Disabled),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),
@ -343,6 +377,7 @@ func TestService(t *testing.T) {
dirk.WithLogLevel(zerolog.Disabled),
dirk.WithMonitor(nullmetrics.New(ctx)),
dirk.WithClientMonitor(nullmetrics.New(ctx)),
dirk.WithProcessConcurrency(1),
dirk.WithEndpoints([]string{"localhost:12345", "localhost:12346"}),
dirk.WithAccountPaths([]string{"wallet1", "wallet2"}),
dirk.WithClientCert([]byte(resources.ClientTest01Crt)),

View File

@ -74,6 +74,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
attestationsSubmitter: parameters.attestationsSubmitter,
beaconAttestationsSigner: parameters.beaconAttestationsSigner,
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
return s, nil
}

View File

@ -66,6 +66,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
attestationAggregator: parameters.attestationAggregator,
submitter: parameters.beaconCommitteeSubmitter,
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
return s, nil
}

View File

@ -57,6 +57,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
aggregateAttestationsSubmitters: parameters.aggregateAttestationsSubmitters,
beaconCommitteeSubscriptionSubmitters: parameters.beaconCommitteeSubscriptionsSubmitters,
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
return s, nil
}

View File

@ -54,6 +54,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
processConcurrency: parameters.processConcurrency,
aggregateAttestationProviders: parameters.aggregateAttestationProviders,
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
return s, nil
}

View File

@ -54,6 +54,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
processConcurrency: parameters.processConcurrency,
attestationDataProviders: parameters.attestationDataProviders,
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
return s, nil
}

View File

@ -56,6 +56,7 @@ func New(ctx context.Context, params ...Parameter) (*Service, error) {
timeout: parameters.timeout,
clientMonitor: parameters.clientMonitor,
}
log.Trace().Int64("process_concurrency", s.processConcurrency).Msg("Set process concurrency")
return s, nil
}

39
util/concurrency.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright © 2021 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
// ProcessConcurrency returns the best process concurrency for the path.
func ProcessConcurrency(path string) int64 {
if path == "" {
return viper.GetInt64("process-concurrency")
}
key := fmt.Sprintf("%s.process-concurrency", path)
if viper.GetString(key) != "" {
return viper.GetInt64(key)
}
// Lop off the child and try again.
lastPeriod := strings.LastIndex(path, ".")
if lastPeriod == -1 {
return ProcessConcurrency("")
}
return ProcessConcurrency(path[0:lastPeriod])
}

91
util/concurrency_test.go Normal file
View File

@ -0,0 +1,91 @@
// Copyright © 2021 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util_test
import (
"fmt"
"os"
"strings"
"testing"
"github.com/attestantio/vouch/util"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
)
func TestProcessConcurrency(t *testing.T) {
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_", ".", "_"))
viper.AutomaticEnv()
tests := []struct {
name string
path string
env map[string]string
expected int64
}{
{
name: "Empty",
env: map[string]string{
"PROCESS_CONCURRENCY": "12345",
},
expected: 12345,
},
{
name: "MultilevelRoot",
env: map[string]string{
"PROCESS_CONCURRENCY": "12345",
},
path: "a.b.c.process-concurrency",
expected: 12345,
},
{
name: "MultilevelBranch",
env: map[string]string{
"PROCESS_CONCURRENCY": "12345",
"A_B_PROCESS_CONCURRENCY": "54321",
},
path: "a.b.c.process-concurrency",
expected: 54321,
},
{
name: "Unknown",
env: map[string]string{
"FOO": "12345",
},
path: "process-concurrency",
expected: 0,
},
{
name: "Fallback",
env: map[string]string{
"PROCESS_CONCURRENCY": "12345",
},
path: "foo",
expected: 12345,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
prefix := fmt.Sprintf("VOUCH_%s", strings.ToUpper(test.name))
for k, v := range test.env {
os.Setenv(fmt.Sprintf("%s_%s", prefix, k), v)
}
viper.SetEnvPrefix(prefix)
res := util.ProcessConcurrency(test.path)
require.Equal(t, test.expected, res)
})
}
}

View File

@ -1,97 +0,0 @@
// Copyright © 2020 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"errors"
"runtime"
"sync"
)
// ScatterResult is the result of a single scatter worker.
type ScatterResult struct {
// Offset is the offset at which the worker started.
Offset int
// Extent is the user-defined result of running the scatter function.
Extent interface{}
}
// Scatter scatters a computation across multiple goroutines, returning a set of per-worker results
func Scatter(inputLen int, work func(int, int, *sync.RWMutex) (interface{}, error)) ([]*ScatterResult, error) {
if inputLen <= 0 {
return nil, errors.New("no data with which to work")
}
extentSize := calculateExtentSize(inputLen)
workers := inputLen / extentSize
if inputLen%extentSize != 0 {
workers++
}
resultCh := make(chan *ScatterResult, workers)
defer close(resultCh)
errorCh := make(chan error, workers)
defer close(errorCh)
mutex := new(sync.RWMutex)
for worker := 0; worker < workers; worker++ {
offset := worker * extentSize
entries := extentSize
if offset+entries > inputLen {
entries = inputLen - offset
}
go func(offset int, entries int) {
extent, err := work(offset, entries, mutex)
if err != nil {
errorCh <- err
} else {
resultCh <- &ScatterResult{
Offset: offset,
Extent: extent,
}
}
}(offset, entries)
}
// Collect results from workers
results := make([]*ScatterResult, workers)
var err error
for i := 0; i < workers; i++ {
select {
case result := <-resultCh:
results[i] = result
case err = <-errorCh:
// Error occurred; don't return because that closes the channels
// and can cause other workers to write to the closed channel.
}
}
return results, err
}
// calculateExtentSize calculates the extent size given the number of items and maximum processors available.
func calculateExtentSize(items int) int {
// Start with an even split.
extentSize := items / runtime.GOMAXPROCS(0)
if extentSize == 0 {
// We must have an extent size of at least 1.
return 1
}
if items%extentSize > 0 {
// We have a remainder; add one to the extent size to ensure we capture it.
extentSize++
}
return extentSize
}

View File

@ -1,76 +0,0 @@
// Copyright © 2020 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util_test
import (
"crypto/rand"
"crypto/sha256"
"sync"
"testing"
"github.com/attestantio/vouch/util"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
var input [][]byte
const (
benchmarkElements = 65536
benchmarkElementSize = 32
benchmarkHashRuns = 128
)
func init() {
input = make([][]byte, benchmarkElements)
for i := 0; i < benchmarkElements; i++ {
input[i] = make([]byte, benchmarkElementSize)
_, err := rand.Read(input[i])
if err != nil {
log.WithError(err).Debug("Cannot read from rand")
}
}
}
// hash is a simple worker function that carries out repeated hashging of its input to provide an output.
func hash(input [][]byte) [][]byte {
output := make([][]byte, len(input))
for i := range input {
copy(output, input)
for j := 0; j < benchmarkHashRuns; j++ {
hash := sha256.Sum256(output[i])
output[i] = hash[:]
}
}
return output
}
func BenchmarkHash(b *testing.B) {
for i := 0; i < b.N; i++ {
hash(input)
}
}
func BenchmarkHashMP(b *testing.B) {
output := make([][]byte, len(input))
for i := 0; i < b.N; i++ {
workerResults, err := util.Scatter(len(input), func(offset int, entries int, _ *sync.RWMutex) (interface{}, error) {
return hash(input[offset : offset+entries]), nil
})
require.NoError(b, err)
for _, result := range workerResults {
copy(output[result.Offset:], result.Extent.([][]byte))
}
}
}

View File

@ -1,116 +0,0 @@
// Copyright © 2020 Attestant Limited.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util_test
import (
"errors"
"sync"
"testing"
"github.com/attestantio/vouch/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDouble(t *testing.T) {
tests := []struct {
name string
inValues int
err string
}{
{
name: "0",
inValues: 0,
err: "no data with which to work",
},
{
name: "1",
inValues: 1,
},
{
name: "1023",
inValues: 1023,
},
{
name: "1024",
inValues: 1024,
},
{
name: "1025",
inValues: 1025,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
inValues := make([]int, test.inValues)
for i := 0; i < test.inValues; i++ {
inValues[i] = i
}
outValues := make([]int, test.inValues)
workerResults, err := util.Scatter(len(inValues), func(offset int, entries int, _ *sync.RWMutex) (interface{}, error) {
extent := make([]int, entries)
for i := 0; i < entries; i++ {
extent[i] = inValues[offset+i] * 2
}
return extent, nil
})
if test.err != "" {
assert.Equal(t, test.err, err.Error())
} else {
require.NoError(t, err)
for _, result := range workerResults {
copy(outValues[result.Offset:], result.Extent.([]int))
}
for i := 0; i < test.inValues; i++ {
require.Equal(t, inValues[i]*2, outValues[i], "Outvalue at %d incorrect", i)
}
}
})
}
}
func TestMutex(t *testing.T) {
totalRuns := 1048576
val := 0
_, err := util.Scatter(totalRuns, func(offset int, entries int, mu *sync.RWMutex) (interface{}, error) {
for i := 0; i < entries; i++ {
mu.Lock()
val++
mu.Unlock()
}
return nil, nil
})
require.NoError(t, err)
require.Equal(t, totalRuns, val)
}
func TestError(t *testing.T) {
totalRuns := 1024
val := 0
_, err := util.Scatter(totalRuns, func(offset int, entries int, mu *sync.RWMutex) (interface{}, error) {
for i := 0; i < entries; i++ {
mu.Lock()
val++
if val == 1011 {
mu.Unlock()
return nil, errors.New("bad number")
}
mu.Unlock()
}
return nil, nil
})
require.EqualError(t, err, "bad number")
}