package ioutil
import (
+ "errors"
"os"
"path/filepath"
"strconv"
dir = os.TempDir()
}
- prefix, suffix := prefixAndSuffix(pattern)
+ prefix, suffix, err := prefixAndSuffix(pattern)
+ if err != nil {
+ return
+ }
nconflict := 0
for i := 0; i < 10000; i++ {
return
}
+var errPatternHasSeparator = errors.New("pattern contains path separator")
+
// prefixAndSuffix splits pattern by the last wildcard "*", if applicable,
// returning prefix as the part before "*" and suffix as the part after "*".
-func prefixAndSuffix(pattern string) (prefix, suffix string) {
+func prefixAndSuffix(pattern string) (prefix, suffix string, err error) {
+ if strings.ContainsRune(pattern, os.PathSeparator) {
+ err = errPatternHasSeparator
+ return
+ }
if pos := strings.LastIndex(pattern, "*"); pos != -1 {
prefix, suffix = pattern[:pos], pattern[pos+1:]
} else {
dir = os.TempDir()
}
- prefix, suffix := prefixAndSuffix(pattern)
+ prefix, suffix, err := prefixAndSuffix(pattern)
+ if err != nil {
+ return
+ }
nconflict := 0
for i := 0; i < 10000; i++ {
}
}
+func TestTempFile_BadPattern(t *testing.T) {
+ tmpDir, err := TempDir("", t.Name())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ const sep = string(os.PathSeparator)
+ tests := []struct {
+ pattern string
+ wantErr bool
+ } {
+ {"ioutil*test", false},
+ {"ioutil_test*foo", false},
+ {"ioutil_test" + sep + "foo", true},
+ {"ioutil_test*" + sep + "foo", true},
+ {"ioutil_test" + sep + "*foo", true},
+ {sep + "ioutil_test" + sep + "*foo", true},
+ {"ioutil_test*foo" + sep, true},
+ }
+ for _, tt := range tests {
+ t.Run(tt.pattern, func(t *testing.T) {
+ tmpfile, err := TempFile(tmpDir, tt.pattern)
+ defer func() {
+ if tmpfile != nil {
+ tmpfile.Close()
+ }
+ }()
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("Expected an error for pattern %q", tt.pattern)
+ }
+ if g, w := err, errPatternHasSeparator; g != w {
+ t.Errorf("Error mismatch: got %#v, want %#v for pattern %q", g, w, tt.pattern)
+ }
+ } else if err != nil {
+ t.Errorf("Unexpected error %v for pattern %q", err, tt.pattern)
+ }
+ })
+ }
+}
+
func TestTempDir(t *testing.T) {
name, err := TempDir("/_not_exists_", "foo")
if name != "" || err == nil {
t.Errorf("TempDir error = %#v; want PathError for path %q satisifying os.IsNotExist", err, badDir)
}
}
+
+func TestTempDir_BadPattern(t *testing.T) {
+ tmpDir, err := TempDir("", t.Name())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ const sep = string(os.PathSeparator)
+ tests := []struct {
+ pattern string
+ wantErr bool
+ } {
+ {"ioutil*test", false},
+ {"ioutil_test*foo", false},
+ {"ioutil_test" + sep + "foo", true},
+ {"ioutil_test*" + sep + "foo", true},
+ {"ioutil_test" + sep + "*foo", true},
+ {sep + "ioutil_test" + sep + "*foo", true},
+ {"ioutil_test*foo" + sep, true},
+ }
+ for _, tt := range tests {
+ t.Run(tt.pattern, func(t *testing.T) {
+ _, err := TempDir(tmpDir, tt.pattern)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("Expected an error for pattern %q", tt.pattern)
+ }
+ if g, w := err, errPatternHasSeparator; g != w {
+ t.Errorf("Error mismatch: got %#v, want %#v for pattern %q", g, w, tt.pattern)
+ }
+ } else if err != nil {
+ t.Errorf("Unexpected error %v for pattern %q", err, tt.pattern)
+ }
+ })
+ }
+}