]> Cypherpunks repositories - gostls13.git/commitdiff
binary search for constant case statements.
authorKen Thompson <ken@golang.org>
Sun, 8 Mar 2009 01:33:42 +0000 (17:33 -0800)
committerKen Thompson <ken@golang.org>
Sun, 8 Mar 2009 01:33:42 +0000 (17:33 -0800)
R=r
OCL=25890
CL=25890

src/cmd/gc/swt.c

index fb3e4b095bbf4c8b62272616837aad47ffce9a82..5f014e9a9f5f51ef0c3e35fb130e56edcfb6be04 100644 (file)
@@ -11,6 +11,7 @@ enum
        Sfalse,
        Stype,
 };
+Node*  binarysw(Node *t, Iter *save, Node *name);
 
 /*
  * walktype
@@ -263,6 +264,7 @@ prepsw(Node *sw, int arg)
        Iter save;
        Node *name, *bool, *cas;
        Node *t, *a;
+//dump("prepsw before", sw->nbody->left);
 
        cas = N;
        name = N;
@@ -279,7 +281,7 @@ prepsw(Node *sw, int arg)
 loop:
        if(t == N) {
                sw->nbody->left = rev(cas);
-//dump("case", sw->nbody->left);
+//dump("prepsw after", sw->nbody->left);
                return;
        }
 
@@ -309,21 +311,29 @@ loop:
                goto loop;
        }
 
-       a = nod(OIF, N, N);
-       a->nbody = t->right;                            // then goto l
 
        switch(arg) {
        default:
                // not bool const
+               a = binarysw(t, &save, name);
+               if(a != N)
+                       break;
+
+               a = nod(OIF, N, N);
                a->ntest = nod(OEQ, name, t->left);     // if name == val
+               a->nbody = t->right;                    // then goto l
                break;
 
        case Strue:
+               a = nod(OIF, N, N);
                a->ntest = t->left;                     // if val
+               a->nbody = t->right;                    // then goto l
                break;
 
        case Sfalse:
+               a = nod(OIF, N, N);
                a->ntest = nod(ONOT, t->left, N);       // if !val
+               a->nbody = t->right;                    // then goto l
                break;
        }
        cas = list(cas, a);
@@ -470,3 +480,202 @@ walkswitch(Node *sw)
        walkstate(sw->nbody);
 //print("normal done\n");
 }
+
+/*
+ * binary search on cases
+ */
+enum
+{
+       Ncase   = 4,    // needed to binary search
+};
+
+typedef        struct  Case    Case;
+struct Case
+{
+       Node*   node;           // points at case statement
+       Case*   link;           // linked list to link
+};
+#define        C       ((Case*)nil)
+
+int
+iscaseconst(Node *t)
+{
+       if(t == N || t->left == N)
+               return 0;
+       switch(whatis(t->left)) {
+       case Wlitfloat:
+       case Wlitint:
+       case Wlitstr:
+               return 1;
+       }
+       return 0;
+}
+
+int
+countcase(Node *t, Iter save)
+{
+       int n;
+
+       // note that the iter is by value,
+       // so cases are not really consumed
+       for(n=0;; n++) {
+               if(!iscaseconst(t))
+                       return n;
+               t = listnext(&save);
+       }
+}
+
+Case*
+csort(Case *l, int(*f)(Case*, Case*))
+{
+       Case *l1, *l2, *le;
+
+       if(l == C || l->link == C)
+               return l;
+
+       l1 = l;
+       l2 = l;
+       for(;;) {
+               l2 = l2->link;
+               if(l2 == C)
+                       break;
+               l2 = l2->link;
+               if(l2 == C)
+                       break;
+               l1 = l1->link;
+       }
+
+       l2 = l1->link;
+       l1->link = C;
+       l1 = csort(l, f);
+       l2 = csort(l2, f);
+
+       /* set up lead element */
+       if((*f)(l1, l2) < 0) {
+               l = l1;
+               l1 = l1->link;
+       } else {
+               l = l2;
+               l2 = l2->link;
+       }
+       le = l;
+
+       for(;;) {
+               if(l1 == C) {
+                       while(l2) {
+                               le->link = l2;
+                               le = l2;
+                               l2 = l2->link;
+                       }
+                       le->link = C;
+                       break;
+               }
+               if(l2 == C) {
+                       while(l1) {
+                               le->link = l1;
+                               le = l1;
+                               l1 = l1->link;
+                       }
+                       break;
+               }
+               if((*f)(l1, l2) < 0) {
+                       le->link = l1;
+                       le = l1;
+                       l1 = l1->link;
+               } else {
+                       le->link = l2;
+                       le = l2;
+                       l2 = l2->link;
+               }
+       }
+       le->link = C;
+       return l;
+}
+
+int
+casecmp(Case *c1, Case *c2)
+{
+       int w;
+
+       w = whatis(c1->node->left);
+       if(w != whatis(c2->node->left))
+               fatal("casecmp1");
+
+       switch(w) {
+       case Wlitfloat:
+               return mpcmpfltflt(c1->node->left->val.u.fval, c2->node->left->val.u.fval);
+       case Wlitint:
+               return mpcmpfixfix(c1->node->left->val.u.xval, c2->node->left->val.u.xval);
+       case Wlitstr:
+               return cmpslit(c1->node->left, c2->node->left);
+       }
+
+       fatal("casecmp2");
+       return 0;
+}
+
+Node*
+constsw(Case *c0, int ncase, Node *name)
+{
+       Node *cas, *a;
+       Case *c;
+       int i, n;
+
+       // small number do sequentially
+       if(ncase < Ncase) {
+               cas = N;
+               for(i=0; i<ncase; i++) {
+                       a = nod(OIF, N, N);
+                       a->ntest = nod(OEQ, name, c0->node->left);
+                       a->nbody = c0->node->right;
+                       cas = list(cas, a);
+                       c0 = c0->link;
+               }
+               return rev(cas);
+       }
+
+       // find center and recur
+       c = c0;
+       n = ncase>>1;
+       for(i=0; i<n; i++)
+               c = c->link;
+
+       a = nod(OIF, N, N);
+       a->ntest = nod(OLE, name, c->node->left);
+       a->nbody = constsw(c0, n+1, name);      // include center
+       a->nelse = constsw(c->link, ncase-n-1, name);   // exclude center
+       return a;
+}
+
+Node*
+binarysw(Node *t, Iter *save, Node *name)
+{
+       Case *c, *c1;
+       int i, ncase;
+       Node *a;
+
+       ncase = countcase(t, *save);
+       if(ncase < Ncase)
+               return N;
+
+       c = C;
+       for(i=1; i<ncase; i++) {
+               c1 = mal(sizeof(*c1));
+               c1->link = c;
+               c1->node = t;
+               c = c1;
+
+               t = listnext(save);
+       }
+
+       // last one shouldnt consume the iter
+       c1 = mal(sizeof(*c1));
+       c1->link = c;
+       c1->node = t;
+       c = c1;
+
+       c = csort(c, casecmp);
+       a = constsw(c, ncase, name);
+//dump("bin", a);
+       return a;
+}