]> Cypherpunks repositories - gostls13.git/commitdiff
Add new functions to the iterable package:
authorDavid Symonds <dsymonds@golang.org>
Thu, 9 Apr 2009 04:50:40 +0000 (21:50 -0700)
committerDavid Symonds <dsymonds@golang.org>
Thu, 9 Apr 2009 04:50:40 +0000 (21:50 -0700)
- Filter
- Find
- Partition

R=rsc
APPROVED=rsc
DELTA=117  (92 added, 17 deleted, 8 changed)
OCL=27135
CL=27240

src/lib/container/iterable.go
src/lib/container/iterable_test.go

index 7963d14b57a41ba7188a408199bfb1387e02c1a3..a4a0d4c0f7446667054eb8d175a706327c3fb49d 100644 (file)
@@ -10,15 +10,17 @@ package iterable
 
 import "vector"
 
-
 type Iterable interface {
        // Iter should return a fresh channel each time it is called.
        Iter() <-chan interface {}
 }
 
+func not(f func(interface {}) bool) (func(interface {}) bool) {
+  return func(e interface {}) bool { return !f(e) }
+}
 
 // All tests whether f is true for every element of iter.
-func All(iter Iterable, f func(interface {}) bool) bool {
+func All(iter Iterable, f func(interface {}) bool) bool {
        for e := range iter.Iter() {
                if !f(e) {
                        return false
@@ -27,13 +29,11 @@ func All(iter Iterable, f func(e interface {}) bool) bool {
        return true
 }
 
-
 // Any tests whether f is true for at least one element of iter.
-func Any(iter Iterable, f func(interface {}) bool) bool {
-       return !All(iter, func(e interface {}) bool { return !f(e) });
+func Any(iter Iterable, f func(interface {}) bool) bool {
+       return !All(iter, not(f))
 }
 
-
 // Data returns a slice containing the elements of iter.
 func Data(iter Iterable) []interface {} {
        vec := vector.New(0);
@@ -43,6 +43,41 @@ func Data(iter Iterable) []interface {} {
        return vec.Data()
 }
 
+// filteredIterable is a struct that implements Iterable with each element
+// passed through a filter.
+type filteredIterable struct {
+       it Iterable;
+       f func(interface {}) bool;
+}
+
+func (f *filteredIterable) iterate(out chan<- interface {}) {
+       for e := range f.it.Iter() {
+               if f.f(e) {
+                       out <- e
+               }
+       }
+       close(out)
+}
+
+func (f *filteredIterable) Iter() <-chan interface {} {
+       ch := make(chan interface {});
+       go f.iterate(ch);
+       return ch;
+}
+
+// Filter returns an Iterable that returns the elements of iter that satisfy f.
+func Filter(iter Iterable, f func(interface {}) bool) Iterable {
+       return &filteredIterable{ iter, f }
+}
+
+// Find returns the first element of iter that satisfies f.
+// Returns nil if no such element is found.
+func Find(iter Iterable, f func(interface {}) bool) interface {} {
+       for e := range Filter(iter, f).Iter() {
+               return e
+       }
+       return nil
+}
 
 // mappedIterable is a helper struct that implements Iterable, returned by Map.
 type mappedIterable struct {
@@ -50,31 +85,30 @@ type mappedIterable struct {
        f func(interface {}) interface {};
 }
 
-
-func (m mappedIterable) iterate(out chan<- interface {}) {
+func (m *mappedIterable) iterate(out chan<- interface {}) {
        for e := range m.it.Iter() {
                out <- m.f(e)
        }
        close(out)
 }
 
-
-func (m mappedIterable) Iter() <-chan interface {} {
+func (m *mappedIterable) Iter() <-chan interface {} {
        ch := make(chan interface {});
        go m.iterate(ch);
        return ch;
 }
 
-
 // Map returns an Iterable that returns the result of applying f to each
 // element of iter.
-func Map(iter Iterable, f func(interface {}) interface {}) Iterable {
-       return mappedIterable{ iter, f }
+func Map(iter Iterable, f func(interface {}) interface {}) Iterable {
+       return &mappedIterable{ iter, f }
 }
 
+// Partition(iter, f) returns Filter(iter, f) and Filter(iter, !f).
+func Partition(iter Iterable, f func(interface {}) bool) (Iterable, Iterable)  {
+  return Filter(iter, f), Filter(iter, not(f))
+}
 
 // TODO:
-// - Find, Filter
 // - Inject
-// - Partition
 // - Zip
index 9c7d291105a64b22272dff2f69c77c23cde1c196..702ebe861b99e5cb7bf5e5d6a8b137bc21a54030 100644 (file)
@@ -43,6 +43,17 @@ func addOne(n interface {}) interface {} {
        return n.(int) + 1
 }
 
+// A stream of the natural numbers: 0, 1, 2, 3, ...
+type integerStream struct {}
+func (i integerStream) Iter() <-chan interface {} {
+  ch := make(chan interface {});
+  go func() {
+    for i := 0; ; i++ {
+      ch <- i
+    }
+  }();
+  return ch
+}
 
 func TestAll(t *testing.T) {
        if !All(oneToFive, isPositive) {
@@ -53,7 +64,6 @@ func TestAll(t *testing.T) {
        }
 }
 
-
 func TestAny(t *testing.T) {
        if Any(oneToFive, isNegative) {
                t.Error("Any(oneToFive, isNegative) == true")
@@ -63,16 +73,47 @@ func TestAny(t *testing.T) {
        }
 }
 
-
-func TestMap(t *testing.T) {
-       res := Data(Map(Map(oneToFive, doubler), addOne));
-       if len(res) != len(oneToFive) {
-               t.Fatal("len(res) = %v, want %v", len(res), len(oneToFive))
+func assertArraysAreEqual(t *testing.T, res []interface {}, expected []int) {
+       if len(res) != len(expected) {
+               t.Errorf("len(res) = %v, want %v", len(res), len(expected));
+               goto missing
        }
-       expected := []int{ 3, 5, 7, 9, 11 };
        for i := range res {
-               if res[i].(int) != expected[i] {
-                       t.Errorf("res[%v] = %v, want %v", i, res[i], expected[i])
+               if v := res[i].(int); v != expected[i] {
+                       t.Errorf("res[%v] = %v, want %v", i, v, expected[i]);
+                       goto missing
                }
        }
+       return;
+missing:
+       t.Errorf("res = %v\nwant  %v", res, expected);
+}
+
+func TestFilter(t *testing.T) {
+       ints := integerStream{};
+       moreInts := Filter(ints, isAbove3).Iter();
+       res := make([]interface {}, 3);
+       for i := 0; i < 3; i++ {
+               res[i] = <-moreInts;
+       }
+       assertArraysAreEqual(t, res, []int{ 4, 5, 6 })
+}
+
+func TestFind(t *testing.T) {
+       ints := integerStream{};
+       first := Find(ints, isAbove3);
+       if first.(int) != 4 {
+               t.Errorf("Find(ints, isAbove3) = %v, want 4", first)
+       }
+}
+
+func TestMap(t *testing.T) {
+       res := Data(Map(Map(oneToFive, doubler), addOne));
+       assertArraysAreEqual(t, res, []int{ 3, 5, 7, 9, 11 })
+}
+
+func TestPartition(t *testing.T) {
+       ti, fi := Partition(oneToFive, isEven);
+       assertArraysAreEqual(t, Data(ti), []int{ 2, 4 });
+       assertArraysAreEqual(t, Data(fi), []int{ 1, 3, 5 })
 }