]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: add parallel for algorithm
authorDmitriy Vyukov <dvyukov@google.com>
Fri, 11 May 2012 06:50:03 +0000 (10:50 +0400)
committerDmitriy Vyukov <dvyukov@google.com>
Fri, 11 May 2012 06:50:03 +0000 (10:50 +0400)
This is factored out part of:
https://golang.org/cl/5279048/
(parallel GC)

R=bsiegert, mpimenov, rsc, minux.ma, r
CC=golang-dev
https://golang.org/cl/5986054

src/cmd/dist/buildruntime.c
src/pkg/runtime/export_test.go
src/pkg/runtime/parfor.c [new file with mode: 0644]
src/pkg/runtime/parfor_test.go [new file with mode: 0644]
src/pkg/runtime/runtime.h

index 5bf6047cbf90aa7be7c0fb10e4d536b2632d0a9a..454d594e5d51cf80145dc8bb41cb37606561c780 100644 (file)
@@ -256,6 +256,7 @@ static char *runtimedefs[] = {
        "iface.c",
        "hashmap.c",
        "chan.c",
+       "parfor.c",
 };
 
 // mkzruntimedefs writes zruntime_defs_$GOOS_$GOARCH.h,
index d50040adcfb02580c5c9327453b59d95bca73007..c1971cd2d18af077d5a542ffac3bb71b0951b77e 100644 (file)
@@ -36,3 +36,28 @@ func lfstackpop2(head *uint64) *LFNode
 
 var LFStackPush = lfstackpush
 var LFStackPop = lfstackpop2
+
+type ParFor struct {
+       body    *byte
+       done    uint32
+       Nthr    uint32
+       nthrmax uint32
+       thrseq  uint32
+       Cnt     uint32
+       Ctx     *byte
+       wait    bool
+}
+
+func parforalloc2(nthrmax uint32) *ParFor
+func parforsetup2(desc *ParFor, nthr, n uint32, ctx *byte, wait bool, body func(*ParFor, uint32))
+func parfordo(desc *ParFor)
+func parforiters(desc *ParFor, tid uintptr) (uintptr, uintptr)
+
+var NewParFor = parforalloc2
+var ParForSetup = parforsetup2
+var ParForDo = parfordo
+
+func ParForIters(desc *ParFor, tid uint32) (uint32, uint32) {
+       begin, end := parforiters(desc, uintptr(tid))
+       return uint32(begin), uint32(end)
+}
diff --git a/src/pkg/runtime/parfor.c b/src/pkg/runtime/parfor.c
new file mode 100644 (file)
index 0000000..7ebbaac
--- /dev/null
@@ -0,0 +1,210 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Parallel for algorithm.
+
+#include "runtime.h"
+#include "arch_GOARCH.h"
+
+struct ParForThread
+{
+       // the thread's iteration space [32lsb, 32msb)
+       uint64 pos;
+       // stats
+       uint64 nsteal;
+       uint64 nstealcnt;
+       uint64 nprocyield;
+       uint64 nosyield;
+       uint64 nsleep;
+       byte pad[CacheLineSize];
+};
+
+ParFor*
+runtime·parforalloc(uint32 nthrmax)
+{
+       ParFor *desc;
+
+       // The ParFor object is followed by CacheLineSize padding
+       // and then nthrmax ParForThread.
+       desc = (ParFor*)runtime·malloc(sizeof(ParFor) + CacheLineSize + nthrmax * sizeof(ParForThread));
+       desc->thr = (ParForThread*)((byte*)(desc+1) + CacheLineSize);
+       desc->nthrmax = nthrmax;
+       return desc;
+}
+
+// For testing from Go
+// func parforalloc2(nthrmax uint32) *ParFor
+void
+runtime·parforalloc2(uint32 nthrmax, ParFor *desc)
+{
+       desc = runtime·parforalloc(nthrmax);
+       FLUSH(&desc);
+}
+
+void
+runtime·parforsetup(ParFor *desc, uint32 nthr, uint32 n, void *ctx, bool wait, void (*body)(ParFor*, uint32))
+{
+       uint32 i, begin, end;
+
+       if(desc == nil || nthr == 0 || nthr > desc->nthrmax || body == nil) {
+               runtime·printf("desc=%p nthr=%d count=%d body=%p\n", desc, nthr, n, body);
+               runtime·throw("parfor: invalid args");
+       }
+
+       desc->body = body;
+       desc->done = 0;
+       desc->nthr = nthr;
+       desc->thrseq = 0;
+       desc->cnt = n;
+       desc->ctx = ctx;
+       desc->wait = wait;
+       desc->nsteal = 0;
+       desc->nstealcnt = 0;
+       desc->nprocyield = 0;
+       desc->nosyield = 0;
+       desc->nsleep = 0;
+       for(i=0; i<nthr; i++) {
+               begin = (uint64)n*i / nthr;
+               end = (uint64)n*(i+1) / nthr;
+               desc->thr[i].pos = (uint64)begin | (((uint64)end)<<32);
+       }
+}
+
+// For testing from Go
+// func parforsetup2(desc *ParFor, nthr, n uint32, ctx *byte, wait bool, body func(*ParFor, uint32))
+void
+runtime·parforsetup2(ParFor *desc, uint32 nthr, uint32 n, void *ctx, bool wait, void *body)
+{
+       runtime·parforsetup(desc, nthr, n, ctx, wait, (void(*)(ParFor*, uint32))body);
+}
+
+void
+runtime·parfordo(ParFor *desc)
+{
+       ParForThread *me;
+       uint32 tid, begin, end, begin2, try, victim, i;
+       uint64 *mypos, *victimpos, pos, newpos;
+       void (*body)(ParFor*, uint32);
+       bool idle;
+
+       // Obtain 0-based thread index.
+       tid = runtime·xadd(&desc->thrseq, 1) - 1;
+       if(tid >= desc->nthr) {
+               runtime·printf("tid=%d nthr=%d\n", tid, desc->nthr);
+               runtime·throw("parfor: invalid tid");
+       }
+
+       // If single-threaded, just execute the for serially.
+       if(desc->nthr==1) {
+               for(i=0; i<desc->cnt; i++)
+                       desc->body(desc, i);
+               return;
+       }
+
+       body = desc->body;
+       me = &desc->thr[tid];
+       mypos = &me->pos;
+       for(;;) {
+               for(;;) {
+                       // While there is local work,
+                       // bump low index and execute the iteration.
+                       pos = runtime·xadd64(mypos, 1);
+                       begin = (uint32)pos-1;
+                       end = (uint32)(pos>>32);
+                       if(begin < end) {
+                               body(desc, begin);
+                               continue;
+                       }
+                       break;
+               }
+
+               // Out of work, need to steal something.
+               idle = false;
+               for(try=0;; try++) {
+                       // If we don't see any work for long enough,
+                       // increment the done counter...
+                       if(try > desc->nthr*4 && !idle) {
+                               idle = true;
+                               runtime·xadd(&desc->done, 1);
+                       }
+                       // ...if all threads have incremented the counter,
+                       // we are done.
+                       if(desc->done + !idle == desc->nthr) {
+                               if(!idle)
+                                       runtime·xadd(&desc->done, 1);
+                               goto exit;
+                       }
+                       // Choose a random victim for stealing.
+                       victim = runtime·fastrand1() % (desc->nthr-1);
+                       if(victim >= tid)
+                               victim++;
+                       victimpos = &desc->thr[victim].pos;
+                       pos = runtime·atomicload64(victimpos);
+                       for(;;) {
+                               // See if it has any work.
+                               begin = (uint32)pos;
+                               end = (uint32)(pos>>32);
+                               if(begin >= end-1) {
+                                       begin = end = 0;
+                                       break;
+                               }
+                               if(idle) {
+                                       runtime·xadd(&desc->done, -1);
+                                       idle = false;
+                               }
+                               begin2 = begin + (end-begin)/2;
+                               newpos = (uint64)begin | (uint64)begin2<<32;
+                               if(runtime·cas64(victimpos, &pos, newpos)) {
+                                       begin = begin2;
+                                       break;
+                               }
+                       }
+                       if(begin < end) {
+                               // Has successfully stolen some work.
+                               if(idle)
+                                       runtime·throw("parfor: should not be idle");
+                               runtime·atomicstore64(mypos, (uint64)begin | (uint64)end<<32);
+                               me->nsteal++;
+                               me->nstealcnt += end-begin;
+                               break;
+                       }
+                       // Backoff.
+                       if(try < desc->nthr) {
+                               // nothing
+                       } else if (try < 4*desc->nthr) {
+                               me->nprocyield++;
+                               runtime·procyield(20);
+                       // If a caller asked not to wait for the others, exit now
+                       // (assume that most work is already done at this point).
+                       } else if (!desc->wait) {
+                               if(!idle)
+                                       runtime·xadd(&desc->done, 1);
+                               goto exit;
+                       } else if (try < 6*desc->nthr) {
+                               me->nosyield++;
+                               runtime·osyield();
+                       } else {
+                               me->nsleep++;
+                               runtime·usleep(1);
+                       }
+               }
+       }
+exit:
+       runtime·xadd64(&desc->nsteal, me->nsteal);
+       runtime·xadd64(&desc->nstealcnt, me->nstealcnt);
+       runtime·xadd64(&desc->nprocyield, me->nprocyield);
+       runtime·xadd64(&desc->nosyield, me->nosyield);
+       runtime·xadd64(&desc->nsleep, me->nsleep);
+}
+
+// For testing from Go
+// func parforiters(desc *ParFor, tid uintptr) (uintptr, uintptr)
+void
+runtime·parforiters(ParFor *desc, uintptr tid, uintptr start, uintptr end)
+{
+       start = (uint32)desc->thr[tid].pos;
+       end = (uint32)(desc->thr[tid].pos>>32);
+       FLUSH(&start);
+       FLUSH(&end);
+}
diff --git a/src/pkg/runtime/parfor_test.go b/src/pkg/runtime/parfor_test.go
new file mode 100644 (file)
index 0000000..055c134
--- /dev/null
@@ -0,0 +1,117 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package runtime_test
+
+import (
+       . "runtime"
+       "testing"
+       "unsafe"
+)
+
+// Simple serial sanity test for parallelfor.
+func TestParFor(t *testing.T) {
+       const P = 1
+       const N = 20
+       data := make([]uint64, N)
+       for i := uint64(0); i < N; i++ {
+               data[i] = i
+       }
+       desc := NewParFor(P)
+       ParForSetup(desc, P, N, nil, true, func(desc *ParFor, i uint32) {
+               data[i] = data[i]*data[i] + 1
+       })
+       ParForDo(desc)
+       for i := uint64(0); i < N; i++ {
+               if data[i] != i*i+1 {
+                       t.Fatalf("Wrong element %d: %d", i, data[i])
+               }
+       }
+}
+
+// Test that nonblocking parallelfor does not block.
+func TestParFor2(t *testing.T) {
+       const P = 7
+       const N = 1003
+       data := make([]uint64, N)
+       for i := uint64(0); i < N; i++ {
+               data[i] = i
+       }
+       desc := NewParFor(P)
+       ParForSetup(desc, P, N, (*byte)(unsafe.Pointer(&data)), false, func(desc *ParFor, i uint32) {
+               d := *(*[]uint64)(unsafe.Pointer(desc.Ctx))
+               d[i] = d[i]*d[i] + 1
+       })
+       for p := 0; p < P; p++ {
+               ParForDo(desc)
+       }
+       for i := uint64(0); i < N; i++ {
+               if data[i] != i*i+1 {
+                       t.Fatalf("Wrong element %d: %d", i, data[i])
+               }
+       }
+}
+
+// Test that iterations are properly distributed.
+func TestParForSetup(t *testing.T) {
+       const P = 11
+       const N = 101
+       desc := NewParFor(P)
+       for n := uint32(0); n < N; n++ {
+               for p := uint32(1); p <= P; p++ {
+                       ParForSetup(desc, p, n, nil, true, func(desc *ParFor, i uint32) {})
+                       sum := uint32(0)
+                       size0 := uint32(0)
+                       end0 := uint32(0)
+                       for i := uint32(0); i < p; i++ {
+                               begin, end := ParForIters(desc, i)
+                               size := end - begin
+                               sum += size
+                               if i == 0 {
+                                       size0 = size
+                                       if begin != 0 {
+                                               t.Fatalf("incorrect begin: %d (n=%d, p=%d)", begin, n, p)
+                                       }
+                               } else {
+                                       if size != size0 && size != size0+1 {
+                                               t.Fatalf("incorrect size: %d/%d (n=%d, p=%d)", size, size0, n, p)
+                                       }
+                                       if begin != end0 {
+                                               t.Fatalf("incorrect begin/end: %d/%d (n=%d, p=%d)", begin, end0, n, p)
+                                       }
+                               }
+                               end0 = end
+                       }
+                       if sum != n {
+                               t.Fatalf("incorrect sum: %d/%d (p=%d)", sum, n, p)
+                       }
+               }
+       }
+}
+
+// Test parallel parallelfor.
+func TestParForParallel(t *testing.T) {
+       N := uint64(1e7)
+       if testing.Short() {
+               N /= 10
+       }
+       data := make([]uint64, N)
+       for i := uint64(0); i < N; i++ {
+               data[i] = i
+       }
+       P := GOMAXPROCS(-1)
+       desc := NewParFor(uint32(P))
+       ParForSetup(desc, uint32(P), uint32(N), nil, true, func(desc *ParFor, i uint32) {
+               data[i] = data[i]*data[i] + 1
+       })
+       for p := 1; p < P; p++ {
+               go ParForDo(desc)
+       }
+       ParForDo(desc)
+       for i := uint64(0); i < N; i++ {
+               if data[i] != i*i+1 {
+                       t.Fatalf("Wrong element %d: %d", i, data[i])
+               }
+       }
+}
index 672e05bfc90150da944bb65640c5e71078c34c64..15f5fa31c85fbff8f6fd52fe28e594960db96fc4 100644 (file)
@@ -73,6 +73,8 @@ typedef       struct  Timers          Timers;
 typedef        struct  Timer           Timer;
 typedef struct GCStats         GCStats;
 typedef struct LFNode          LFNode;
+typedef struct ParFor          ParFor;
+typedef struct ParForThread    ParForThread;
 
 /*
  * per-cpu declaration.
@@ -359,6 +361,27 @@ struct LFNode
        uintptr pushcnt;
 };
 
+// Parallel for descriptor.
+struct ParFor
+{
+       void (*body)(ParFor*, uint32);  // executed for each element
+       uint32 done;                    // number of idle threads
+       uint32 nthr;                    // total number of threads
+       uint32 nthrmax;                 // maximum number of threads
+       uint32 thrseq;                  // thread id sequencer
+       uint32 cnt;                     // iteration space [0, cnt)
+       void *ctx;                      // arbitrary user context
+       bool wait;                      // if true, wait while all threads finish processing,
+                                       // otherwise parfor may return while other threads are still working
+       ParForThread *thr;              // array of thread descriptors
+       // stats
+       uint64 nsteal;
+       uint64 nstealcnt;
+       uint64 nprocyield;
+       uint64 nosyield;
+       uint64 nsleep;
+};
+
 /*
  * defined macros
  *    you need super-gopher-guru privilege
@@ -668,6 +691,18 @@ void       runtime·futexwakeup(uint32*, uint32);
 void   runtime·lfstackpush(uint64 *head, LFNode *node);
 LFNode*        runtime·lfstackpop(uint64 *head);
 
+/*
+ * Parallel for over [0, n).
+ * body() is executed for each iteration.
+ * nthr - total number of worker threads.
+ * ctx - arbitrary user context.
+ * if wait=true, threads return from parfor() when all work is done;
+ * otherwise, threads can return while other threads are still finishing processing.
+ */
+ParFor*        runtime·parforalloc(uint32 nthrmax);
+void   runtime·parforsetup(ParFor *desc, uint32 nthr, uint32 n, void *ctx, bool wait, void (*body)(ParFor*, uint32));
+void   runtime·parfordo(ParFor *desc);
+
 /*
  * This is consistent across Linux and BSD.
  * If a new OS is added that is different, move this to