diff --git a/hw06_pipeline_execution/pipeline.go b/hw06_pipeline_execution/pipeline.go index b6fd2c9..d9b70bb 100644 --- a/hw06_pipeline_execution/pipeline.go +++ b/hw06_pipeline_execution/pipeline.go @@ -25,11 +25,19 @@ func chanWrap(in In, done In) Out { go func() { defer close(out) - for val := range in { + for { select { - case out <- val: case <-done: return + case val, ok := <-in: + if !ok { + return + } + select { + case <-done: + return + case out <- val: + } } } }() diff --git a/hw06_pipeline_execution/pipeline_test.go b/hw06_pipeline_execution/pipeline_test.go index 5cb1bcd..63c3a59 100644 --- a/hw06_pipeline_execution/pipeline_test.go +++ b/hw06_pipeline_execution/pipeline_test.go @@ -135,3 +135,79 @@ func TestPipeline(t *testing.T) { require.Equal(t, []string{}, result) }) } + +func TestChanWrap(t *testing.T) { + t.Run("To end", func(t *testing.T) { + in := make(Bi) + done := make(Bi) + data := []int{1, 2, 3, 4, 5} + + go func() { + for _, v := range data { + in <- v + time.Sleep(time.Millisecond * 10) + } + close(in) + }() + + res := make([]int, 0, 5) + for i := range chanWrap(in, done) { + res = append(res, i.(int)) + } + + require.Len(t, res, 5) + require.Equal(t, []int{1, 2, 3, 4, 5}, res) + }) + + t.Run("Close Done", func(t *testing.T) { + in := make(Bi) + done := make(Bi) + data := []int{1, 2, 3, 4, 5} + + go func() { + for i, v := range data { + if i == 2 { + close(done) + return + } + in <- v + time.Sleep(time.Millisecond * 10) + } + close(in) + }() + + res := make([]int, 0, 5) + for i := range chanWrap(in, done) { + res = append(res, i.(int)) + } + + require.Len(t, res, 2) + require.Equal(t, []int{1, 2}, res) + }) + + t.Run("Close In", func(t *testing.T) { + in := make(Bi) + done := make(Bi) + data := []int{1, 2, 3, 4, 5} + + go func() { + for i, v := range data { + if i == 2 { + close(in) + return + } + in <- v + time.Sleep(time.Millisecond * 10) + } + close(in) + }() + + res := make([]int, 0, 5) + for i := range chanWrap(in, done) { + res = append(res, i.(int)) + } + + require.Len(t, res, 2) + require.Equal(t, []int{1, 2}, res) + }) +}