diff --git a/CHANGELOG.md b/CHANGELOG.md
index 40016c9ed..8490ab2c8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,18 @@
# Changelog
+## v4.11.3 - 2023-11-07
+
+**Security**
+
+* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541)
+
+**Enhancements**
+
+* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540)
+* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537)
+* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536)
+
+
## v4.11.2 - 2023-10-11
**Security**
diff --git a/binder.go b/binder.go
index 29cceca0b..8e7b81413 100644
--- a/binder.go
+++ b/binder.go
@@ -1323,7 +1323,7 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi
case time.Second:
*dest = time.Unix(n, 0)
case time.Millisecond:
- *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows
+ *dest = time.UnixMilli(n)
case time.Nanosecond:
*dest = time.Unix(0, n)
}
diff --git a/context.go b/context.go
index 27da28a9c..6a1811685 100644
--- a/context.go
+++ b/context.go
@@ -584,8 +584,10 @@ func (c *context) Inline(file, name string) error {
return c.contentDisposition(file, name, "inline")
}
+var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
+
func (c *context) contentDisposition(file, name, dispositionType string) error {
- c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name))
+ c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name)))
return c.File(file)
}
diff --git a/context_test.go b/context_test.go
index 11a63cfce..01a8784b8 100644
--- a/context_test.go
+++ b/context_test.go
@@ -19,7 +19,7 @@ import (
"time"
"github.com/labstack/gommon/log"
- testify "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/assert"
)
type (
@@ -85,303 +85,443 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context)
return t.templates.ExecuteTemplate(w, name, data)
}
-type responseWriterErr struct {
-}
-
-func (responseWriterErr) Header() http.Header {
- return http.Header{}
-}
+func TestContextEcho(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
-func (responseWriterErr) Write([]byte) (int, error) {
- return 0, errors.New("err")
-}
+ c := e.NewContext(req, rec).(*context)
-func (responseWriterErr) WriteHeader(statusCode int) {
+ assert.Equal(t, e, c.Echo())
}
-func TestContext(t *testing.T) {
+func TestContextRequest(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
+
c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
+ assert.NotNil(t, c.Request())
+ assert.Equal(t, req, c.Request())
+}
- // Echo
- assert.Equal(e, c.Echo())
+func TestContextResponse(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
- // Request
- assert.NotNil(c.Request())
+ c := e.NewContext(req, rec).(*context)
- // Response
- assert.NotNil(c.Response())
+ assert.NotNil(t, c.Response())
+}
+
+func TestContextRenderTemplate(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
- //--------
- // Render
- //--------
+ c := e.NewContext(req, rec).(*context)
tmpl := &Template{
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
}
c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("Hello, Jon Snow!", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
}
+}
+
+func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec).(*context)
c.echo.Renderer = nil
- err = c.Render(http.StatusOK, "hello", "Jon Snow")
- assert.Error(err)
-
- // JSON
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON+"\n", rec.Body.String())
- }
-
- // JSON with "?pretty"
- req = httptest.NewRequest(http.MethodGet, "/?pretty", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
- }
- req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
-
- // JSONPretty
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
- }
-
- // JSON (error)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, make(chan bool))
- assert.Error(err)
-
- // JSONP
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
+ assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow"))
+}
+
+func TestContextJSON(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.JSON(http.StatusOK, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
+ }
+}
+
+func TestContextJSONErrorsOut(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.JSON(http.StatusOK, make(chan bool))
+ assert.EqualError(t, err, "json: unsupported type: chan bool")
+}
+
+func TestContextJSONPrettyURL(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.JSON(http.StatusOK, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
+ }
+}
+
+func TestContextJSONPretty(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
+ }
+}
+
+func TestContextJSONWithEmptyIntent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ u := user{1, "Jon Snow"}
+ emptyIndent := ""
+ buf := new(bytes.Buffer)
+
+ enc := json.NewEncoder(buf)
+ enc.SetIndent(emptyIndent, emptyIndent)
+ _ = enc.Encode(u)
+ err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, buf.String(), rec.Body.String())
+ }
+}
+
+func TestContextJSONP(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
callback := "callback"
- err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String())
- }
-
- // XML
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXML, rec.Body.String())
- }
-
- // XML with "?pretty"
- req = httptest.NewRequest(http.MethodGet, "/?pretty", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
- }
- req = httptest.NewRequest(http.MethodGet, "/", nil)
-
- // XML (error)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, make(chan bool))
- assert.Error(err)
-
- // XML response write error
- c = e.NewContext(req, rec).(*context)
- c.response.Writer = responseWriterErr{}
- err = c.XML(0, 0)
- testify.Error(t, err)
-
- // XMLPretty
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
- }
-
- t.Run("empty indent", func(t *testing.T) {
- var (
- u = user{1, "Jon Snow"}
- buf = new(bytes.Buffer)
- emptyIndent = ""
- )
-
- t.Run("json", func(t *testing.T) {
- buf.Reset()
- assert := testify.New(t)
-
- // New JSONBlob with empty indent
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- enc := json.NewEncoder(buf)
- enc.SetIndent(emptyIndent, emptyIndent)
- err = enc.Encode(u)
- err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(buf.String(), rec.Body.String())
- }
- })
+ err := c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
+ }
+}
- t.Run("xml", func(t *testing.T) {
- buf.Reset()
- assert := testify.New(t)
-
- // New XMLBlob with empty indent
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- enc := xml.NewEncoder(buf)
- enc.Indent(emptyIndent, emptyIndent)
- err = enc.Encode(u)
- err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+buf.String(), rec.Body.String())
- }
- })
- })
+func TestContextJSONBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
- // Legacy JSONBlob
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
data, err := json.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON, rec.Body.String())
- }
-
- // Legacy JSONPBlob
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- callback = "callback"
- data, err = json.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON, rec.Body.String())
+ }
+}
+
+func TestContextJSONPBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ callback := "callback"
+ data, err := json.Marshal(user{1, "Jon Snow"})
+ assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(callback+"("+userJSON+");", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
+ }
+}
+
+func TestContextXML(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.XML(http.StatusOK, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
+ }
+}
+
+func TestContextXMLPrettyURL(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.XML(http.StatusOK, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
}
+}
- // Legacy XMLBlob
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- data, err = xml.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+func TestContextXMLPretty(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
+ }
+}
+
+func TestContextXMLBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ data, err := xml.Marshal(user{1, "Jon Snow"})
+ assert.NoError(t, err)
err = c.XMLBlob(http.StatusOK, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXML, rec.Body.String())
- }
-
- // String
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.String(http.StatusOK, "Hello, World!")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal("Hello, World!", rec.Body.String())
- }
-
- // HTML
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.HTML(http.StatusOK, "Hello, World!")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal("Hello, World!", rec.Body.String())
- }
-
- // Stream
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
+ }
+}
+
+func TestContextXMLWithEmptyIntent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ u := user{1, "Jon Snow"}
+ emptyIndent := ""
+ buf := new(bytes.Buffer)
+
+ enc := xml.NewEncoder(buf)
+ enc.Indent(emptyIndent, emptyIndent)
+ _ = enc.Encode(u)
+ err := c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
+ }
+}
+
+type responseWriterErr struct {
+}
+
+func (responseWriterErr) Header() http.Header {
+ return http.Header{}
+}
+
+func (responseWriterErr) Write([]byte) (int, error) {
+ return 0, errors.New("responseWriterErr")
+}
+
+func (responseWriterErr) WriteHeader(statusCode int) {
+}
+
+func TestContextXMLError(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+ c.response.Writer = responseWriterErr{}
+
+ err := c.XML(http.StatusOK, make(chan bool))
+ assert.EqualError(t, err, "responseWriterErr")
+}
+
+func TestContextString(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.String(http.StatusOK, "Hello, World!")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+ }
+}
+
+func TestContextHTML(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.HTML(http.StatusOK, "Hello, World!")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+ }
+}
+
+func TestContextStream(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
r := strings.NewReader("response from a stream")
- err = c.Stream(http.StatusOK, "application/octet-stream", r)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType))
- assert.Equal("response from a stream", rec.Body.String())
- }
-
- // Attachment
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.Attachment("_fixture/images/walle.png", "walle.png")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
- assert.Equal(219885, rec.Body.Len())
- }
-
- // Inline
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.Inline("_fixture/images/walle.png", "walle.png")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
- assert.Equal(219885, rec.Body.Len())
- }
-
- // NoContent
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
+ err := c.Stream(http.StatusOK, "application/octet-stream", r)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "response from a stream", rec.Body.String())
+ }
+}
+
+func TestContextAttachment(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenName string
+ expectHeader string
+ }{
+ {
+ name: "ok",
+ whenName: "walle.png",
+ expectHeader: `attachment; filename="walle.png"`,
+ },
+ {
+ name: "ok, escape quotes in malicious filename",
+ whenName: `malicious.sh"; \"; dummy=.txt`,
+ expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.Attachment("_fixture/images/walle.png", tc.whenName)
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition))
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, 219885, rec.Body.Len())
+ }
+ })
+ }
+}
+
+func TestContextInline(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenName string
+ expectHeader string
+ }{
+ {
+ name: "ok",
+ whenName: "walle.png",
+ expectHeader: `inline; filename="walle.png"`,
+ },
+ {
+ name: "ok, escape quotes in malicious filename",
+ whenName: `malicious.sh"; \"; dummy=.txt`,
+ expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.Inline("_fixture/images/walle.png", tc.whenName)
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition))
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, 219885, rec.Body.Len())
+ }
+ })
+ }
+}
+
+func TestContextNoContent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
c.NoContent(http.StatusOK)
- assert.Equal(http.StatusOK, rec.Code)
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestContextError(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
- // Error
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
c.Error(errors.New("error"))
- assert.Equal(http.StatusInternalServerError, rec.Code)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.True(t, c.Response().Committed)
+}
+
+func TestContextReset(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
- // Reset
c.SetParamNames("foo")
c.SetParamValues("bar")
c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": {"baz"}})
+
c.Reset(req, httptest.NewRecorder())
- assert.Equal(0, len(c.ParamValues()))
- assert.Equal(0, len(c.ParamNames()))
- assert.Equal(0, len(c.store))
- assert.Equal("", c.Path())
- assert.Equal(0, len(c.QueryParams()))
+
+ assert.Len(t, c.ParamValues(), 0)
+ assert.Len(t, c.ParamNames(), 0)
+ assert.Len(t, c.Path(), 0)
+ assert.Len(t, c.QueryParams(), 0)
+ assert.Len(t, c.store, 0)
}
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
@@ -391,11 +531,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
- assert := testify.New(t)
- if assert.NoError(err) {
- assert.Equal(http.StatusCreated, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
}
}
@@ -406,9 +545,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
- assert := testify.New(t)
- if assert.Error(err) {
- assert.False(c.response.Committed)
+ if assert.Error(t, err) {
+ assert.False(t, c.response.Committed)
}
}
@@ -422,22 +560,20 @@ func TestContextCookie(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
-
// Read single
cookie, err := c.Cookie("theme")
- if assert.NoError(err) {
- assert.Equal("theme", cookie.Name)
- assert.Equal("light", cookie.Value)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "theme", cookie.Name)
+ assert.Equal(t, "light", cookie.Value)
}
// Read multiple
for _, cookie := range c.Cookies() {
switch cookie.Name {
case "theme":
- assert.Equal("light", cookie.Value)
+ assert.Equal(t, "light", cookie.Value)
case "user":
- assert.Equal("Jon Snow", cookie.Value)
+ assert.Equal(t, "Jon Snow", cookie.Value)
}
}
@@ -452,11 +588,11 @@ func TestContextCookie(t *testing.T) {
HttpOnly: true,
}
c.SetCookie(cookie)
- assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPath(t *testing.T) {
@@ -469,14 +605,12 @@ func TestContextPath(t *testing.T) {
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1", c)
- assert := testify.New(t)
-
- assert.Equal("/users/:id", c.Path())
+ assert.Equal(t, "/users/:id", c.Path())
r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
c = e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1/files/1", c)
- assert.Equal("/users/:uid/files/:fid", c.Path())
+ assert.Equal(t, "/users/:uid/files/:fid", c.Path())
}
func TestContextPathParam(t *testing.T) {
@@ -486,15 +620,15 @@ func TestContextPathParam(t *testing.T) {
// ParamNames
c.SetParamNames("uid", "fid")
- testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
+ assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
// ParamValues
c.SetParamValues("101", "501")
- testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
+ assert.EqualValues(t, []string{"101", "501"}, c.ParamValues())
// Param
- testify.Equal(t, "501", c.Param("fid"))
- testify.Equal(t, "", c.Param("undefined"))
+ assert.Equal(t, "501", c.Param("fid"))
+ assert.Equal(t, "", c.Param("undefined"))
}
func TestContextGetAndSetParam(t *testing.T) {
@@ -507,23 +641,21 @@ func TestContextGetAndSetParam(t *testing.T) {
// round-trip param values with modification
paramVals := c.ParamValues()
- testify.EqualValues(t, []string{""}, c.ParamValues())
+ assert.EqualValues(t, []string{""}, c.ParamValues())
paramVals[0] = "bar"
c.SetParamValues(paramVals...)
- testify.EqualValues(t, []string{"bar"}, c.ParamValues())
+ assert.EqualValues(t, []string{"bar"}, c.ParamValues())
// shouldn't explode during Reset() afterwards!
- testify.NotPanics(t, func() {
+ assert.NotPanics(t, func() {
c.Reset(nil, nil)
})
}
// Issue #1655
func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) {
- assert := testify.New(t)
-
e := New()
- assert.Equal(0, *e.maxParam)
+ assert.Equal(t, 0, *e.maxParam)
expectedOneParam := []string{"one"}
expectedTwoParams := []string{"one", "two"}
@@ -533,23 +665,23 @@ func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) {
c := e.NewContext(nil, nil)
c.SetParamNames("1", "2")
c.SetParamValues(expectedTwoParams...)
- assert.Equal(2, *e.maxParam)
- assert.EqualValues(expectedTwoParams, c.ParamValues())
+ assert.Equal(t, 2, *e.maxParam)
+ assert.EqualValues(t, expectedTwoParams, c.ParamValues())
c.SetParamNames("1")
- assert.Equal(2, *e.maxParam)
+ assert.Equal(t, 2, *e.maxParam)
// Here for backward compatibility the ParamValues remains as they are
- assert.EqualValues(expectedOneParam, c.ParamValues())
+ assert.EqualValues(t, expectedOneParam, c.ParamValues())
c.SetParamNames("1", "2", "3")
- assert.Equal(3, *e.maxParam)
+ assert.Equal(t, 3, *e.maxParam)
// Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam
- assert.EqualValues(expectedThreeParams, c.ParamValues())
+ assert.EqualValues(t, expectedThreeParams, c.ParamValues())
c.SetParamValues("A", "B", "C", "D")
- assert.Equal(3, *e.maxParam)
+ assert.Equal(t, 3, *e.maxParam)
// Here D shouldn't be returned
- assert.EqualValues(expectedABCParams, c.ParamValues())
+ assert.EqualValues(t, expectedABCParams, c.ParamValues())
}
func TestContextFormValue(t *testing.T) {
@@ -563,13 +695,13 @@ func TestContextFormValue(t *testing.T) {
c := e.NewContext(req, nil)
// FormValue
- testify.Equal(t, "Jon Snow", c.FormValue("name"))
- testify.Equal(t, "jon@labstack.com", c.FormValue("email"))
+ assert.Equal(t, "Jon Snow", c.FormValue("name"))
+ assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
// FormParams
params, err := c.FormParams()
- if testify.NoError(t, err) {
- testify.Equal(t, url.Values{
+ if assert.NoError(t, err) {
+ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, params)
@@ -580,8 +712,8 @@ func TestContextFormValue(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil)
params, err = c.FormParams()
- testify.Nil(t, params)
- testify.Error(t, err)
+ assert.Nil(t, params)
+ assert.Error(t, err)
}
func TestContextQueryParam(t *testing.T) {
@@ -593,11 +725,11 @@ func TestContextQueryParam(t *testing.T) {
c := e.NewContext(req, nil)
// QueryParam
- testify.Equal(t, "Jon Snow", c.QueryParam("name"))
- testify.Equal(t, "jon@labstack.com", c.QueryParam("email"))
+ assert.Equal(t, "Jon Snow", c.QueryParam("name"))
+ assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
// QueryParams
- testify.Equal(t, url.Values{
+ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, c.QueryParams())
@@ -608,7 +740,7 @@ func TestContextFormFile(t *testing.T) {
buf := new(bytes.Buffer)
mr := multipart.NewWriter(buf)
w, err := mr.CreateFormFile("file", "test")
- if testify.NoError(t, err) {
+ if assert.NoError(t, err) {
w.Write([]byte("test"))
}
mr.Close()
@@ -617,8 +749,8 @@ func TestContextFormFile(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.FormFile("file")
- if testify.NoError(t, err) {
- testify.Equal(t, "test", f.Filename)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "test", f.Filename)
}
}
@@ -633,8 +765,8 @@ func TestContextMultipartForm(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.MultipartForm()
- if testify.NoError(t, err) {
- testify.NotNil(t, f)
+ if assert.NoError(t, err) {
+ assert.NotNil(t, f)
}
}
@@ -643,16 +775,16 @@ func TestContextRedirect(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
- testify.Equal(t, http.StatusMovedPermanently, rec.Code)
- testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
- testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
+ assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
+ assert.Equal(t, http.StatusMovedPermanently, rec.Code)
+ assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
+ assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
}
func TestContextStore(t *testing.T) {
var c Context = new(context)
c.Set("name", "Jon Snow")
- testify.Equal(t, "Jon Snow", c.Get("name"))
+ assert.Equal(t, "Jon Snow", c.Get("name"))
}
func BenchmarkContext_Store(b *testing.B) {
@@ -682,19 +814,19 @@ func TestContextHandler(t *testing.T) {
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/handler", c)
err := c.Handler()(c)
- testify.Equal(t, "handler", b.String())
- testify.NoError(t, err)
+ assert.Equal(t, "handler", b.String())
+ assert.NoError(t, err)
}
func TestContext_SetHandler(t *testing.T) {
var c Context = new(context)
- testify.Nil(t, c.Handler())
+ assert.Nil(t, c.Handler())
c.SetHandler(func(c Context) error {
return nil
})
- testify.NotNil(t, c.Handler())
+ assert.NotNil(t, c.Handler())
}
func TestContext_Path(t *testing.T) {
@@ -703,7 +835,7 @@ func TestContext_Path(t *testing.T) {
var c Context = new(context)
c.SetPath(path)
- testify.Equal(t, path, c.Path())
+ assert.Equal(t, path, c.Path())
}
type validator struct{}
@@ -716,10 +848,10 @@ func TestContext_Validate(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
- testify.Error(t, c.Validate(struct{}{}))
+ assert.Error(t, c.Validate(struct{}{}))
e.Validator = &validator{}
- testify.NoError(t, c.Validate(struct{}{}))
+ assert.NoError(t, c.Validate(struct{}{}))
}
func TestContext_QueryString(t *testing.T) {
@@ -730,18 +862,18 @@ func TestContext_QueryString(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
c := e.NewContext(req, nil)
- testify.Equal(t, queryString, c.QueryString())
+ assert.Equal(t, queryString, c.QueryString())
}
func TestContext_Request(t *testing.T) {
var c Context = new(context)
- testify.Nil(t, c.Request())
+ assert.Nil(t, c.Request())
req := httptest.NewRequest(http.MethodGet, "/path", nil)
c.SetRequest(req)
- testify.Equal(t, req, c.Request())
+ assert.Equal(t, req, c.Request())
}
func TestContext_Scheme(t *testing.T) {
@@ -798,14 +930,14 @@ func TestContext_Scheme(t *testing.T) {
}
for _, tt := range tests {
- testify.Equal(t, tt.s, tt.c.Scheme())
+ assert.Equal(t, tt.s, tt.c.Scheme())
}
}
func TestContext_IsWebSocket(t *testing.T) {
tests := []struct {
c Context
- ws testify.BoolAssertionFunc
+ ws assert.BoolAssertionFunc
}{
{
&context{
@@ -813,7 +945,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"websocket"}},
},
},
- testify.True,
+ assert.True,
},
{
&context{
@@ -821,13 +953,13 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
},
},
- testify.True,
+ assert.True,
},
{
&context{
request: &http.Request{},
},
- testify.False,
+ assert.False,
},
{
&context{
@@ -835,7 +967,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"other"}},
},
},
- testify.False,
+ assert.False,
},
}
@@ -854,8 +986,8 @@ func TestContext_Bind(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u)
- testify.NoError(t, err)
- testify.Equal(t, &user{1, "Jon Snow"}, u)
+ assert.NoError(t, err)
+ assert.Equal(t, &user{1, "Jon Snow"}, u)
}
func TestContext_Logger(t *testing.T) {
@@ -863,15 +995,15 @@ func TestContext_Logger(t *testing.T) {
c := e.NewContext(nil, nil)
log1 := c.Logger()
- testify.NotNil(t, log1)
+ assert.NotNil(t, log1)
log2 := log.New("echo2")
c.SetLogger(log2)
- testify.Equal(t, log2, c.Logger())
+ assert.Equal(t, log2, c.Logger())
// Resetting the context returns the initial logger
c.Reset(nil, nil)
- testify.Equal(t, log1, c.Logger())
+ assert.Equal(t, log1, c.Logger())
}
func TestContext_RealIP(t *testing.T) {
@@ -959,6 +1091,6 @@ func TestContext_RealIP(t *testing.T) {
}
for _, tt := range tests {
- testify.Equal(t, tt.s, tt.c.RealIP())
+ assert.Equal(t, tt.s, tt.c.RealIP())
}
}
diff --git a/echo.go b/echo.go
index 8bdf97539..0ac644924 100644
--- a/echo.go
+++ b/echo.go
@@ -259,7 +259,7 @@ const (
const (
// Version of Echo
- Version = "4.11.2"
+ Version = "4.11.3"
website = "https://echo.labstack.com"
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
banner = `
diff --git a/json_test.go b/json_test.go
index 27ee43e73..8fb9ebc96 100644
--- a/json_test.go
+++ b/json_test.go
@@ -1,7 +1,7 @@
package echo
import (
- testify "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"strings"
@@ -16,16 +16,14 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
-
// Echo
- assert.Equal(e, c.Echo())
+ assert.Equal(t, e, c.Echo())
// Request
- assert.NotNil(c.Request())
+ assert.NotNil(t, c.Request())
// Response
- assert.NotNil(c.Response())
+ assert.NotNil(t, c.Response())
//--------
// Default JSON encoder
@@ -34,16 +32,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
enc := new(DefaultJSONSerializer)
err := enc.Serialize(c, user{1, "Jon Snow"}, "")
- if assert.NoError(err) {
- assert.Equal(userJSON+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
}
req = httptest.NewRequest(http.MethodPost, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = enc.Serialize(c, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
}
@@ -55,16 +53,14 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
-
// Echo
- assert.Equal(e, c.Echo())
+ assert.Equal(t, e, c.Echo())
// Request
- assert.NotNil(c.Request())
+ assert.NotNil(t, c.Request())
// Response
- assert.NotNil(c.Response())
+ assert.NotNil(t, c.Response())
//--------
// Default JSON encoder
@@ -74,8 +70,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
var u = user{}
err := enc.Deserialize(c, &u)
- if assert.NoError(err) {
- assert.Equal(u, user{ID: 1, Name: "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"})
}
var userUnmarshalSyntaxError = user{}
@@ -83,8 +79,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = enc.Deserialize(c, &userUnmarshalSyntaxError)
- assert.IsType(&HTTPError{}, err)
- assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value")
+ assert.IsType(t, &HTTPError{}, err)
+ assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value")
var userUnmarshalTypeError = struct {
ID string `json:"id"`
@@ -95,7 +91,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = enc.Deserialize(c, &userUnmarshalTypeError)
- assert.IsType(&HTTPError{}, err)
- assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string")
+ assert.IsType(t, &HTTPError{}, err)
+ assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string")
}
diff --git a/middleware/cors.go b/middleware/cors.go
index 10504359f..7ace2f224 100644
--- a/middleware/cors.go
+++ b/middleware/cors.go
@@ -39,7 +39,7 @@ type (
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional.
- AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
+ AllowOriginFunc func(origin string) (bool, error) `yaml:"-"`
// AllowMethods determines the value of the Access-Control-Allow-Methods
// response header. This header specified the list of methods allowed when
diff --git a/middleware/proxy.go b/middleware/proxy.go
index e4f98d9ed..16b00d645 100644
--- a/middleware/proxy.go
+++ b/middleware/proxy.go
@@ -359,6 +359,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
c.Set("_error", nil)
}
+ // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
+ // that Balancer may have replaced with c.SetRequest.
+ req = c.Request()
+
// Proxy
switch {
case c.IsWebSocket():
diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go
index 415d68e77..1c93ba031 100644
--- a/middleware/proxy_test.go
+++ b/middleware/proxy_test.go
@@ -747,3 +747,63 @@ func TestProxyBalancerWithNoTargets(t *testing.T) {
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
assert.Nil(t, rrb.Next(nil))
}
+
+type testContextKey string
+
+type customBalancer struct {
+ target *ProxyTarget
+}
+
+func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
+ return false
+}
+
+func (b *customBalancer) RemoveTarget(name string) bool {
+ return false
+}
+
+func (b *customBalancer) Next(c echo.Context) *ProxyTarget {
+ ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
+ c.SetRequest(c.Request().WithContext(ctx))
+ return b.target
+}
+
+func TestModifyResponseUseContext(t *testing.T) {
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("OK"))
+ }),
+ )
+ defer server.Close()
+
+ targetURL, _ := url.Parse(server.URL)
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: &customBalancer{
+ target: &ProxyTarget{
+ Name: "tst",
+ URL: targetURL,
+ },
+ },
+ RetryCount: 1,
+ ModifyResponse: func(res *http.Response) error {
+ val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
+ if valStr, ok := val.(string); ok {
+ res.Header.Set("FROM_BALANCER", valStr)
+ }
+ return nil
+ },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "OK", rec.Body.String())
+ assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
+}
diff --git a/middleware/rewrite.go b/middleware/rewrite.go
index e5b0a6b56..2090eac04 100644
--- a/middleware/rewrite.go
+++ b/middleware/rewrite.go
@@ -27,7 +27,7 @@ type (
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
- RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
+ RegexRules map[*regexp.Regexp]string `yaml:"-"`
}
)