]> Cypherpunks repositories - gostls13.git/commitdiff
Add authentication.
authorFirmansyah Adiputra <frm.adiputra@gmail.com>
Fri, 22 Jan 2010 06:55:44 +0000 (17:55 +1100)
committerNigel Tao <nigeltao@golang.org>
Fri, 22 Jan 2010 06:55:44 +0000 (17:55 +1100)
Other code fixing:
- Fixed bugs in get32.
- Fix code for parsing display string (as a new function).
- Fix code for connecting to X server. The old code only work
  if the server is listening to TCP port, otherwise it doesn't
  work (at least in my PC).

R=nigeltao_golang, rsc, jhh
CC=golang-dev
https://golang.org/cl/183111

src/pkg/xgb/auth.go [new file with mode: 0644]
src/pkg/xgb/xgb.go

diff --git a/src/pkg/xgb/auth.go b/src/pkg/xgb/auth.go
new file mode 100644 (file)
index 0000000..4d920c2
--- /dev/null
@@ -0,0 +1,110 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xgb
+
+import (
+       "bufio"
+       "io"
+       "os"
+)
+
+func getU16BE(r io.Reader, b []byte) (uint16, os.Error) {
+       _, err := io.ReadFull(r, b[0:2])
+       if err != nil {
+               return 0, err
+       }
+       return uint16(b[0])<<8 + uint16(b[1]), nil
+}
+
+func getBytes(r io.Reader, b []byte) ([]byte, os.Error) {
+       n, err := getU16BE(r, b)
+       if err != nil {
+               return nil, err
+       }
+       if int(n) > len(b) {
+               return nil, os.NewError("bytes too long for buffer")
+       }
+       _, err = io.ReadFull(r, b[0:n])
+       if err != nil {
+               return nil, err
+       }
+       return b[0:n], nil
+}
+
+func getString(r io.Reader, b []byte) (string, os.Error) {
+       b, err := getBytes(r, b)
+       if err != nil {
+               return "", err
+       }
+       return string(b), nil
+}
+
+// readAuthority reads the X authority file for the DISPLAY.
+// If hostname == "" or hostname == "localhost",
+// readAuthority uses the system's hostname (as returned by os.Hostname) instead.
+func readAuthority(hostname, display string) (name string, data []byte, err os.Error) {
+       // b is a scratch buffer to use and should be at least 256 bytes long
+       // (i.e. it should be able to hold a hostname).
+       var b [256]byte
+
+       // As per /usr/include/X11/Xauth.h.
+       const familyLocal = 256
+
+       if len(hostname) == 0 || hostname == "localhost" {
+               hostname, err = os.Hostname()
+               if err != nil {
+                       return "", nil, err
+               }
+       }
+
+       fname := os.Getenv("XAUTHORITY")
+       if len(fname) == 0 {
+               home := os.Getenv("HOME")
+               if len(home) == 0 {
+                       err = os.NewError("Xauthority not found: $XAUTHORITY, $HOME not set")
+                       return "", nil, err
+               }
+               fname = home + "/.Xauthority"
+       }
+
+       r, err := os.Open(fname, os.O_RDONLY, 0444)
+       if err != nil {
+               return "", nil, err
+       }
+       defer r.Close()
+
+       br := bufio.NewReader(r)
+       for {
+               family, err := getU16BE(br, b[0:2])
+               if err != nil {
+                       return "", nil, err
+               }
+
+               addr, err := getString(br, b[0:])
+               if err != nil {
+                       return "", nil, err
+               }
+
+               disp, err := getString(br, b[0:])
+               if err != nil {
+                       return "", nil, err
+               }
+
+               name0, err := getString(br, b[0:])
+               if err != nil {
+                       return "", nil, err
+               }
+
+               data0, err := getBytes(br, b[0:])
+               if err != nil {
+                       return "", nil, err
+               }
+
+               if family == familyLocal && addr == hostname && disp == display {
+                       return name0, data0, nil
+               }
+       }
+       panic("unreachable")
+}
index 1cde2b5c93e0c46daec47d343a755470bc57c7b4..3f6f0b0a68f1145e7df21ad27a93d0bed4ac070a 100644 (file)
@@ -18,12 +18,14 @@ import (
 // A Conn represents a connection to an X server.
 // Only one goroutine should use a Conn's methods at a time.
 type Conn struct {
+       host          string
        conn          net.Conn
        nextId        Id
        nextCookie    Cookie
        replies       map[Cookie][]byte
        events        queue
        err           os.Error
+       display       string
        defaultScreen int
        scratch       [32]byte
        Setup         SetupInfo
@@ -89,7 +91,7 @@ func get32(buf []byte) uint32 {
        v := uint32(buf[0])
        v |= uint32(buf[1]) << 8
        v |= uint32(buf[2]) << 16
-       v |= uint32(buf[3]) << 32
+       v |= uint32(buf[3]) << 24
        return v
 }
 
@@ -284,49 +286,39 @@ func (c *Conn) PollForEvent() (Event, os.Error) {
 }
 
 // Dial connects to the X server given in the 'display' string.
-// The display string is typically taken from os.Getenv("DISPLAY").
+// If 'display' is empty it will be taken from os.Getenv("DISPLAY").
+//
+// Examples:
+//     Dial(":1")                 // connect to net.Dial("unix", "", "/tmp/.X11-unix/X1")
+//     Dial("/tmp/launch-123/:0") // connect to net.Dial("unix", "", "/tmp/launch-123/:0")
+//     Dial("hostname:2.1")       // connect to net.Dial("tcp", "", "hostname:6002")
+//     Dial("tcp/hostname:1.0")   // connect to net.Dial("tcp", "", "hostname:6001")
 func Dial(display string) (*Conn, os.Error) {
-       var err os.Error
-
-       c := new(Conn)
+       c, err := connect(display)
+       if err != nil {
+               return nil, err
+       }
 
-       if display[0] == '/' {
-               c.conn, err = net.Dial("unix", "", display)
-               if err != nil {
-                       fmt.Printf("cannot connect: %v\n", err)
-                       return nil, err
-               }
-       } else {
-               parts := strings.Split(display, ":", 2)
-               host := parts[0]
-               port := 0
-               if len(parts) > 1 {
-                       parts = strings.Split(parts[1], ".", 2)
-                       port, _ = strconv.Atoi(parts[0])
-                       if len(parts) > 1 {
-                               c.defaultScreen, _ = strconv.Atoi(parts[1])
-                       }
-               }
-               display = fmt.Sprintf("%s:%d", host, port+6000)
-               c.conn, err = net.Dial("tcp", "", display)
-               if err != nil {
-                       fmt.Printf("cannot connect: %v\n", err)
-                       return nil, err
-               }
+       // Get authentication data
+       authName, authData, err := readAuthority(c.host, c.display)
+       if err != nil {
+               return nil, err
        }
 
-       // TODO: get these from .Xauthority
-       var authName, authData []byte
+       // Assume that the authentication protocol is "MIT-MAGIC-COOKIE-1".
+       if authName != "MIT-MAGIC-COOKIE-1" || len(authData) != 16 {
+               return nil, os.NewError("unsupported auth protocol " + authName)
+       }
 
        buf := make([]byte, 12+pad(len(authName))+pad(len(authData)))
-       buf[0] = 'l'
+       buf[0] = 0x6c
        buf[1] = 0
        put16(buf[2:], 11)
        put16(buf[4:], 0)
        put16(buf[6:], uint16(len(authName)))
        put16(buf[8:], uint16(len(authData)))
        put16(buf[10:], 0)
-       copy(buf[12:], authName)
+       copy(buf[12:], strings.Bytes(authName))
        copy(buf[12+pad(len(authName)):], authData)
        if _, err = c.conn.Write(buf); err != nil {
                return nil, err
@@ -396,3 +388,77 @@ func getClientMessageData(b []byte, v *ClientMessageData) int {
        }
        return 20
 }
+
+func connect(display string) (*Conn, os.Error) {
+       if len(display) == 0 {
+               display = os.Getenv("DISPLAY")
+       }
+
+       display0 := display
+       if len(display) == 0 {
+               return nil, os.NewError("empty display string")
+       }
+
+       colonIdx := strings.LastIndex(display, ":")
+       if colonIdx < 0 {
+               return nil, os.NewError("bad display string: " + display0)
+       }
+
+       var protocol, socket string
+       c := new(Conn)
+
+       if display[0] == '/' {
+               socket = display[0:colonIdx]
+       } else {
+               slashIdx := strings.LastIndex(display, "/")
+               if slashIdx >= 0 {
+                       protocol = display[0:slashIdx]
+                       c.host = display[slashIdx+1 : colonIdx]
+               } else {
+                       c.host = display[0:colonIdx]
+               }
+       }
+
+       display = display[colonIdx+1 : len(display)]
+       if len(display) == 0 {
+               return nil, os.NewError("bad display string: " + display0)
+       }
+
+       var scr string
+       dotIdx := strings.LastIndex(display, ".")
+       if dotIdx < 0 {
+               c.display = display[0:]
+       } else {
+               c.display = display[0:dotIdx]
+               scr = display[dotIdx+1:]
+       }
+
+       dispnum, err := strconv.Atoui(c.display)
+       if err != nil {
+               return nil, os.NewError("bad display string: " + display0)
+       }
+
+       if len(scr) != 0 {
+               c.defaultScreen, err = strconv.Atoi(scr)
+               if err != nil {
+                       return nil, os.NewError("bad display string: " + display0)
+               }
+       }
+
+       // Connect to server
+       if len(socket) != 0 {
+               c.conn, err = net.Dial("unix", "", socket+":"+c.display)
+       } else if len(c.host) != 0 {
+               if protocol == "" {
+                       protocol = "tcp"
+               }
+               c.conn, err = net.Dial(protocol, "", c.host+":"+strconv.Uitoa(6000+dispnum))
+       } else {
+               c.conn, err = net.Dial("unix", "", "/tmp/.X11-unix/X"+c.display)
+       }
+
+       if err != nil {
+               return nil, os.NewError("cannot connect to " + display0 + ": " + err.String())
+       }
+       return c, nil
+}