]> Cypherpunks repositories - gostls13.git/commitdiff
syscall: another attempt to keep windows syscall pointers live
authorAlex Brainman <alex.brainman@gmail.com>
Sun, 5 Oct 2014 02:15:13 +0000 (13:15 +1100)
committerAlex Brainman <alex.brainman@gmail.com>
Sun, 5 Oct 2014 02:15:13 +0000 (13:15 +1100)
This approach was suggested in
https://golang.org/cl/138250043/#msg15.
Unlike current version of mksyscall_windows.go,
new code could be used in go.sys and other external
repos without help from asm.

LGTM=iant
R=golang-codereviews, iant, r
CC=golang-codereviews
https://golang.org/cl/143160046

src/syscall/mksyscall_windows.go
src/syscall/zsyscall_windows.go

index 1cdd6b4d226c66c02fcbdaa12ef36f0757eca1c6..316e88d7ea2e6e0f09d8e76d8d9a5937055674bc 100644 (file)
@@ -138,8 +138,6 @@ func (p *Param) StringTmpVarCode() string {
 // TmpVarCode returns source code for temp variable.
 func (p *Param) TmpVarCode() string {
        switch {
-       case p.Type == "string":
-               return p.StringTmpVarCode()
        case p.Type == "bool":
                return p.BoolTmpVarCode()
        case strings.HasPrefix(p.Type, "[]"):
@@ -149,19 +147,26 @@ func (p *Param) TmpVarCode() string {
        }
 }
 
+// TmpVarHelperCode returns source code for helper's temp variable.
+func (p *Param) TmpVarHelperCode() string {
+       if p.Type != "string" {
+               return ""
+       }
+       return p.StringTmpVarCode()
+}
+
 // SyscallArgList returns source code fragments representing p parameter
 // in syscall. Slices are translated into 2 syscall parameters: pointer to
 // the first element and length.
 func (p *Param) SyscallArgList() []string {
+       t := p.HelperType()
        var s string
        switch {
-       case p.Type[0] == '*':
+       case t[0] == '*':
                s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
-       case p.Type == "string":
-               s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar())
-       case p.Type == "bool":
+       case t == "bool":
                s = p.tmpVar()
-       case strings.HasPrefix(p.Type, "[]"):
+       case strings.HasPrefix(t, "[]"):
                return []string{
                        fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
                        fmt.Sprintf("uintptr(len(%s))", p.Name),
@@ -177,6 +182,14 @@ func (p *Param) IsError() bool {
        return p.Name == "err" && p.Type == "error"
 }
 
+// HelperType returns type of parameter p used in helper function.
+func (p *Param) HelperType() string {
+       if p.Type == "string" {
+               return p.fn.StrconvType()
+       }
+       return p.Type
+}
+
 // join concatenates parameters ps into a string with sep separator.
 // Each parameter is converted into string by applying fn to it
 // before conversion.
@@ -454,6 +467,11 @@ func (f *Fn) ParamList() string {
        return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
 }
 
+// HelperParamList returns source code for helper function f parameters.
+func (f *Fn) HelperParamList() string {
+       return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
+}
+
 // ParamPrintList returns source code of trace printing part correspondent
 // to syscall input parameters.
 func (f *Fn) ParamPrintList() string {
@@ -510,6 +528,19 @@ func (f *Fn) SyscallParamList() string {
        return strings.Join(a, ", ")
 }
 
+// HelperCallParamList returns source code of call into function f helper.
+func (f *Fn) HelperCallParamList() string {
+       a := make([]string, 0, len(f.Params))
+       for _, p := range f.Params {
+               s := p.Name
+               if p.Type == "string" {
+                       s = p.tmpVar()
+               }
+               a = append(a, s)
+       }
+       return strings.Join(a, ", ")
+}
+
 // IsUTF16 is true, if f is W (utf16) function. It is false
 // for all A (ascii) functions.
 func (f *Fn) IsUTF16() bool {
@@ -533,6 +564,25 @@ func (f *Fn) StrconvType() string {
        return "*byte"
 }
 
+// HasStringParam is true, if f has at least one string parameter.
+// Otherwise it is false.
+func (f *Fn) HasStringParam() bool {
+       for _, p := range f.Params {
+               if p.Type == "string" {
+                       return true
+               }
+       }
+       return false
+}
+
+// HelperName returns name of function f helper.
+func (f *Fn) HelperName() string {
+       if !f.HasStringParam() {
+               return f.Name
+       }
+       return "_" + f.Name
+}
+
 // Source files and functions.
 type Source struct {
        Funcs []*Fn
@@ -666,7 +716,7 @@ import "syscall"{{end}}
 var (
 {{template "dlls" .}}
 {{template "funcnames" .}})
-{{range .Funcs}}{{template "funcbody" .}}{{end}}
+{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
 {{end}}
 
 {{/* help functions */}}
@@ -677,16 +727,27 @@ var (
 {{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
 {{end}}{{end}}
 
+{{define "helperbody"}}
+func {{.Name}}({{.ParamList}}) {{template "results" .}}{
+{{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}})
+}
+{{end}}
+
 {{define "funcbody"}}
-func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{
+func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
 {{template "tmpvars" .}}       {{template "syscall" .}}
 {{template "seterror" .}}{{template "printtrace" .}}   return
 }
 {{end}}
 
+{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}    {{.TmpVarHelperCode}}
+{{end}}{{end}}{{end}}
+
 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}        {{.TmpVarCode}}
 {{end}}{{end}}{{end}}
 
+{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}
+
 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
 
 {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}}
index 1f44750b7f3c8b243922c6b77984a82b64ade8c8..afc28f99390129dd7fcb0e65bfbc2e028cfa9822 100644 (file)
@@ -176,7 +176,11 @@ func LoadLibrary(libname string) (handle Handle, err error) {
        if err != nil {
                return
        }
-       r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0)
+       return _LoadLibrary(_p0)
+}
+
+func _LoadLibrary(libname *uint16) (handle Handle, err error) {
+       r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(libname)), 0, 0)
        handle = Handle(r0)
        if handle == 0 {
                if e1 != 0 {
@@ -206,7 +210,11 @@ func GetProcAddress(module Handle, procname string) (proc uintptr, err error) {
        if err != nil {
                return
        }
-       r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(_p0)), 0)
+       return _GetProcAddress(module, _p0)
+}
+
+func _GetProcAddress(module Handle, procname *byte) (proc uintptr, err error) {
+       r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(procname)), 0)
        proc = uintptr(r0)
        if proc == 0 {
                if e1 != 0 {
@@ -1558,7 +1566,11 @@ func GetHostByName(name string) (h *Hostent, err error) {
        if err != nil {
                return
        }
-       r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0)
+       return _GetHostByName(_p0)
+}
+
+func _GetHostByName(name *byte) (h *Hostent, err error) {
+       r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0)
        h = (*Hostent)(unsafe.Pointer(r0))
        if h == nil {
                if e1 != 0 {
@@ -1581,7 +1593,11 @@ func GetServByName(name string, proto string) (s *Servent, err error) {
        if err != nil {
                return
        }
-       r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), 0)
+       return _GetServByName(_p0, _p1)
+}
+
+func _GetServByName(name *byte, proto *byte) (s *Servent, err error) {
+       r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(proto)), 0)
        s = (*Servent)(unsafe.Pointer(r0))
        if s == nil {
                if e1 != 0 {
@@ -1605,7 +1621,11 @@ func GetProtoByName(name string) (p *Protoent, err error) {
        if err != nil {
                return
        }
-       r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0)
+       return _GetProtoByName(_p0)
+}
+
+func _GetProtoByName(name *byte) (p *Protoent, err error) {
+       r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0)
        p = (*Protoent)(unsafe.Pointer(r0))
        if p == nil {
                if e1 != 0 {
@@ -1623,7 +1643,11 @@ func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSR
        if status != nil {
                return
        }
-       r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(_p0)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
+       return _DnsQuery(_p0, qtype, options, extra, qrs, pr)
+}
+
+func _DnsQuery(name *uint16, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) {
+       r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(name)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
        if r0 != 0 {
                status = Errno(r0)
        }