]> Cypherpunks repositories - gostls13.git/commitdiff
mime: add ExtensionByType method
authorNick Cooper <nmvc@google.com>
Thu, 12 Mar 2015 00:23:44 +0000 (11:23 +1100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 27 Mar 2015 16:24:07 +0000 (16:24 +0000)
Added the inverse of TypeByExtension for discovering an appropriate
extensions for a given MIME type.

Fixes #10144

Change-Id: I6a80e1af3db5d45ad6a4c7ff4ccfdf6a4f424367
Reviewed-on: https://go-review.googlesource.com/7444
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/mime/type.go
src/mime/type_test.go

index c605a94787cc656fbb72b7645085fd3713f725f7..89254f5001adc9ad891affdc283bb0f85a3c2c7d 100644 (file)
@@ -25,7 +25,8 @@ var (
                ".svg":  "image/svg+xml",
                ".xml":  "text/xml; charset=utf-8",
        }
-       mimeTypes = clone(mimeTypesLower)
+       mimeTypes  = clone(mimeTypesLower)
+       extensions = invert(mimeTypesLower)
 )
 
 func clone(m map[string]string) map[string]string {
@@ -39,6 +40,18 @@ func clone(m map[string]string) map[string]string {
        return m2
 }
 
+func invert(m map[string]string) map[string][]string {
+       m2 := make(map[string][]string, len(m))
+       for k, v := range m {
+               justType, _, err := ParseMediaType(v)
+               if err != nil {
+                       panic(err)
+               }
+               m2[justType] = append(m2[justType], k)
+       }
+       return m2
+}
+
 var once sync.Once // guards initMime
 
 // TypeByExtension returns the MIME type associated with the file extension ext.
@@ -92,6 +105,26 @@ func TypeByExtension(ext string) string {
        return mimeTypesLower[string(lower)]
 }
 
+// ExtensionsByType returns the extensions known to be associated with the MIME
+// type typ. The returned extensions will each begin with a leading dot, as in
+// ".html". When typ has no associated extensions, ExtensionsByType returns an
+// nil slice.
+func ExtensionsByType(typ string) ([]string, error) {
+       justType, _, err := ParseMediaType(typ)
+       if err != nil {
+               return nil, err
+       }
+
+       once.Do(initMime)
+       mimeLock.RLock()
+       defer mimeLock.RUnlock()
+       s, ok := extensions[justType]
+       if !ok {
+               return nil, nil
+       }
+       return append([]string{}, s...), nil
+}
+
 // AddExtensionType sets the MIME type associated with
 // the extension ext to typ. The extension should begin with
 // a leading dot, as in ".html".
@@ -104,7 +137,7 @@ func AddExtensionType(ext, typ string) error {
 }
 
 func setExtensionType(extension, mimeType string) error {
-       _, param, err := ParseMediaType(mimeType)
+       justType, param, err := ParseMediaType(mimeType)
        if err != nil {
                return err
        }
@@ -115,8 +148,14 @@ func setExtensionType(extension, mimeType string) error {
        extLower := strings.ToLower(extension)
 
        mimeLock.Lock()
+       defer mimeLock.Unlock()
        mimeTypes[extension] = mimeType
        mimeTypesLower[extLower] = mimeType
-       mimeLock.Unlock()
+       for _, v := range extensions[justType] {
+               if v == extLower {
+                       return nil
+               }
+       }
+       extensions[justType] = append(extensions[justType], extLower)
        return nil
 }
index d2d254ae9a5d7dfa6b35aeef0d95c1c0ae8a2894..dabb585e21827b2e18ed63b255477b4f5afeba79 100644 (file)
@@ -5,6 +5,9 @@
 package mime
 
 import (
+       "reflect"
+       "sort"
+       "strings"
        "testing"
 )
 
@@ -43,6 +46,59 @@ func TestTypeByExtensionCase(t *testing.T) {
        }
 }
 
+func TestExtensionsByType(t *testing.T) {
+       for want, typ := range typeTests {
+               val, err := ExtensionsByType(typ)
+               if err != nil {
+                       t.Errorf("error %s for ExtensionsByType(%q)", err, typ)
+                       continue
+               }
+               if len(val) != 1 {
+                       t.Errorf("ExtensionsByType(%q) = %v; expected exactly 1 entry", typ, val)
+                       continue
+               }
+               // We always expect lower case, test data includes upper-case.
+               want = strings.ToLower(want)
+               if val[0] != want {
+                       t.Errorf("ExtensionsByType(%q) = %q, want %q", typ, val[0], want)
+               }
+       }
+}
+
+func TestExtensionsByTypeMultiple(t *testing.T) {
+       const typ = "text/html"
+       exts, err := ExtensionsByType(typ)
+       if err != nil {
+               t.Fatalf("ExtensionsByType(%q) error: %v", typ, err)
+       }
+       sort.Strings(exts)
+       if want := []string{".htm", ".html"}; !reflect.DeepEqual(exts, want) {
+               t.Errorf("ExtensionsByType(%q) = %v; want %v", typ, exts, want)
+       }
+}
+
+func TestExtensionsByTypeNoDuplicates(t *testing.T) {
+       const (
+               typ = "text/html"
+               ext = ".html"
+       )
+       AddExtensionType(ext, typ)
+       AddExtensionType(ext, typ)
+       exts, err := ExtensionsByType(typ)
+       if err != nil {
+               t.Fatalf("ExtensionsByType(%q) error: %v", typ, err)
+       }
+       count := 0
+       for _, v := range exts {
+               if v == ext {
+                       count++
+               }
+       }
+       if count != 1 {
+               t.Errorf("ExtensionsByType(%q) = %v; want %v once", typ, exts, ext)
+       }
+}
+
 func TestLookupMallocs(t *testing.T) {
        n := testing.AllocsPerRun(10000, func() {
                TypeByExtension(".html")