From 4e5c655944c9a636eaed549e6ad8fd8011fb4d42 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Tue, 20 Mar 2018 23:08:51 +0100 Subject: [PATCH] Parallel reaps automatically before returning --- common/async.go | 7 +++--- common/async_test.go | 59 +++++++++++++++++++++++--------------------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/common/async.go b/common/async.go index 64e32076..31dc2e96 100644 --- a/common/async.go +++ b/common/async.go @@ -41,7 +41,7 @@ func (trs *TaskResultSet) Channels() []TaskResultCh { return trs.chz } -func (trs *TaskResultSet) LastResult(index int) (TaskResult, bool) { +func (trs *TaskResultSet) LatestResult(index int) (TaskResult, bool) { if len(trs.results) <= index { return TaskResult{}, false } @@ -50,7 +50,7 @@ func (trs *TaskResultSet) LastResult(index int) (TaskResult, bool) { } // NOTE: Not concurrency safe. -func (trs *TaskResultSet) Reap() { +func (trs *TaskResultSet) Reap() *TaskResultSet { if trs.results == nil { trs.results = make([]taskResultOK, len(trs.chz)) } @@ -67,6 +67,7 @@ func (trs *TaskResultSet) Reap() { // Do nothing. } } + return trs } // Returns the firstmost (by task index) error as @@ -155,5 +156,5 @@ func Parallel(tasks ...Task) (trs *TaskResultSet, ok bool) { // We must do this check here (after DONE_LOOP). ok = ok && (atomic.LoadInt32(numPanics) == 0) - return newTaskResultSet(taskResultChz), ok + return newTaskResultSet(taskResultChz).Reap(), ok } diff --git a/common/async_test.go b/common/async_test.go index 3b47c3fa..2e8db26e 100644 --- a/common/async_test.go +++ b/common/async_test.go @@ -2,6 +2,7 @@ package common import ( "errors" + "fmt" "sync/atomic" "testing" "time" @@ -29,22 +30,27 @@ func TestParallel(t *testing.T) { assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already") var failedTasks int for i := 0; i < len(tasks); i++ { - select { - case taskResult := <-trs.chz[i]: - if taskResult.Error != nil { - assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) - failedTasks += 1 - } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { - failedTasks += 1 - } else { - // Good! - } - default: + taskResult, ok := trs.LatestResult(i) + if !ok { + assert.Fail(t, "Task #%v did not complete.", i) failedTasks += 1 + } else if taskResult.Error != nil { + assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) + failedTasks += 1 + } else if taskResult.Panic != nil { + assert.Fail(t, "Task should not have panic'd but got %v", taskResult.Panic) + failedTasks += 1 + } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { + assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int)) + failedTasks += 1 + } else { + // Good! } } assert.Equal(t, failedTasks, 0, "No task should have failed") - + assert.Nil(t, trs.FirstError(), "There should be no errors") + assert.Nil(t, trs.FirstPanic(), "There should be no panics") + assert.Equal(t, 0, trs.FirstValue(), "First value should be 0") } func TestParallelAbort(t *testing.T) { @@ -90,9 +96,9 @@ func TestParallelAbort(t *testing.T) { flow4 <- <-flow3 // Verify task #0, #1, #2. - waitFor(t, taskResultSet.chz[0], "Task #0", 0, nil, nil) - waitFor(t, taskResultSet.chz[1], "Task #1", 1, errors.New("some error"), nil) - waitFor(t, taskResultSet.chz[2], "Task #2", 2, nil, nil) + checkResult(t, taskResultSet, 0, 0, nil, nil) + checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) + checkResult(t, taskResultSet, 2, 2, nil, nil) } func TestParallelRecover(t *testing.T) { @@ -115,22 +121,19 @@ func TestParallelRecover(t *testing.T) { assert.False(t, ok, "ok should be false since we panic'd in task #2.") // Verify task #0, #1, #2. - waitFor(t, taskResultSet.chz[0], "Task #0", 0, nil, nil) - waitFor(t, taskResultSet.chz[1], "Task #1", 1, errors.New("some error"), nil) - waitFor(t, taskResultSet.chz[2], "Task #2", nil, nil, 2) + checkResult(t, taskResultSet, 0, 0, nil, nil) + checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) + checkResult(t, taskResultSet, 2, nil, nil, 2) } // Wait for result -func waitFor(t *testing.T, taskResultCh TaskResultCh, taskName string, val interface{}, err error, pnk interface{}) { - select { - case taskResult, ok := <-taskResultCh: - assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) - assert.Equal(t, val, taskResult.Value, taskName) - assert.Equal(t, err, taskResult.Error, taskName) - assert.Equal(t, pnk, taskResult.Panic, taskName) - default: - assert.Fail(t, "Failed to receive result for %v", taskName) - } +func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) { + taskResult, ok := taskResultSet.LatestResult(index) + taskName := fmt.Sprintf("Task #%v", index) + assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) + assert.Equal(t, val, taskResult.Value, taskName) + assert.Equal(t, err, taskResult.Error, taskName) + assert.Equal(t, pnk, taskResult.Panic, taskName) } // Wait for timeout (no result)