Skip to content

Commit 05f1ef7

Browse files
authored
fix(middleware): add missing return in RouteHeaders empty check (#1045)
The RouteHeaders middleware was missing a return statement after calling next.ServeHTTP when the router had no routes configured. This caused the next handler to be called twice - once in the empty check and again at the end of the function. Also adds comprehensive test coverage for the RouteHeaders middleware and Pattern matching functionality.
1 parent 6eb3588 commit 05f1ef7

2 files changed

Lines changed: 213 additions & 0 deletions

File tree

middleware/route_headers.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
7979
if len(hr) == 0 {
8080
// skip if no routes set
8181
next.ServeHTTP(w, r)
82+
return
8283
}
8384

8485
// find first matching header route, and continue

middleware/route_headers_test.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"sync/atomic"
7+
"testing"
8+
)
9+
10+
func TestRouteHeaders(t *testing.T) {
11+
t.Run("empty router should call next handler exactly once", func(t *testing.T) {
12+
var callCount atomic.Int32
13+
14+
hr := RouteHeaders()
15+
16+
handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17+
callCount.Add(1)
18+
w.WriteHeader(http.StatusOK)
19+
}))
20+
21+
req := httptest.NewRequest("GET", "/", nil)
22+
rec := httptest.NewRecorder()
23+
24+
handler.ServeHTTP(rec, req)
25+
26+
if callCount.Load() != 1 {
27+
t.Errorf("expected next handler to be called exactly once, but was called %d times", callCount.Load())
28+
}
29+
})
30+
31+
t.Run("matching header should route to correct middleware", func(t *testing.T) {
32+
var matchedRoute string
33+
34+
hr := RouteHeaders().
35+
Route("Host", "example.com", func(next http.Handler) http.Handler {
36+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37+
matchedRoute = "example.com"
38+
next.ServeHTTP(w, r)
39+
})
40+
}).
41+
Route("Host", "other.com", func(next http.Handler) http.Handler {
42+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43+
matchedRoute = "other.com"
44+
next.ServeHTTP(w, r)
45+
})
46+
})
47+
48+
handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49+
w.WriteHeader(http.StatusOK)
50+
}))
51+
52+
req := httptest.NewRequest("GET", "/", nil)
53+
req.Host = "example.com"
54+
req.Header.Set("Host", "example.com")
55+
rec := httptest.NewRecorder()
56+
57+
handler.ServeHTTP(rec, req)
58+
59+
if matchedRoute != "example.com" {
60+
t.Errorf("expected matched route to be 'example.com', got '%s'", matchedRoute)
61+
}
62+
})
63+
64+
t.Run("wildcard pattern should match", func(t *testing.T) {
65+
var matched bool
66+
67+
hr := RouteHeaders().
68+
Route("Host", "*.example.com", func(next http.Handler) http.Handler {
69+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
70+
matched = true
71+
next.ServeHTTP(w, r)
72+
})
73+
})
74+
75+
handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76+
w.WriteHeader(http.StatusOK)
77+
}))
78+
79+
req := httptest.NewRequest("GET", "/", nil)
80+
req.Header.Set("Host", "api.example.com")
81+
rec := httptest.NewRecorder()
82+
83+
handler.ServeHTTP(rec, req)
84+
85+
if !matched {
86+
t.Error("expected wildcard pattern to match")
87+
}
88+
})
89+
90+
t.Run("default route should be used when no match", func(t *testing.T) {
91+
var usedDefault bool
92+
93+
hr := RouteHeaders().
94+
Route("Host", "example.com", func(next http.Handler) http.Handler {
95+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
96+
next.ServeHTTP(w, r)
97+
})
98+
}).
99+
RouteDefault(func(next http.Handler) http.Handler {
100+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
101+
usedDefault = true
102+
next.ServeHTTP(w, r)
103+
})
104+
})
105+
106+
handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107+
w.WriteHeader(http.StatusOK)
108+
}))
109+
110+
req := httptest.NewRequest("GET", "/", nil)
111+
req.Header.Set("Host", "other.com")
112+
rec := httptest.NewRecorder()
113+
114+
handler.ServeHTTP(rec, req)
115+
116+
if !usedDefault {
117+
t.Error("expected default route to be used when no match")
118+
}
119+
})
120+
121+
t.Run("RouteAny should match any of the provided patterns", func(t *testing.T) {
122+
var matched bool
123+
124+
hr := RouteHeaders().
125+
RouteAny("Content-Type", []string{"application/json", "application/xml"}, func(next http.Handler) http.Handler {
126+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
127+
matched = true
128+
next.ServeHTTP(w, r)
129+
})
130+
})
131+
132+
handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133+
w.WriteHeader(http.StatusOK)
134+
}))
135+
136+
// Test with application/json
137+
req := httptest.NewRequest("POST", "/", nil)
138+
req.Header.Set("Content-Type", "application/json")
139+
rec := httptest.NewRecorder()
140+
141+
handler.ServeHTTP(rec, req)
142+
143+
if !matched {
144+
t.Error("expected RouteAny to match 'application/json'")
145+
}
146+
147+
// Reset and test with application/xml
148+
matched = false
149+
req = httptest.NewRequest("POST", "/", nil)
150+
req.Header.Set("Content-Type", "application/xml")
151+
rec = httptest.NewRecorder()
152+
153+
handler.ServeHTTP(rec, req)
154+
155+
if !matched {
156+
t.Error("expected RouteAny to match 'application/xml'")
157+
}
158+
})
159+
160+
t.Run("no match and no default should call next handler", func(t *testing.T) {
161+
var nextCalled bool
162+
163+
hr := RouteHeaders().
164+
Route("Host", "example.com", func(next http.Handler) http.Handler {
165+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
166+
next.ServeHTTP(w, r)
167+
})
168+
})
169+
170+
handler := hr.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171+
nextCalled = true
172+
w.WriteHeader(http.StatusOK)
173+
}))
174+
175+
req := httptest.NewRequest("GET", "/", nil)
176+
req.Header.Set("Host", "other.com")
177+
rec := httptest.NewRecorder()
178+
179+
handler.ServeHTTP(rec, req)
180+
181+
if !nextCalled {
182+
t.Error("expected next handler to be called when no match and no default")
183+
}
184+
})
185+
}
186+
187+
func TestPattern(t *testing.T) {
188+
tests := []struct {
189+
pattern string
190+
value string
191+
expected bool
192+
}{
193+
{"example.com", "example.com", true},
194+
{"example.com", "other.com", false},
195+
{"*.example.com", "api.example.com", true},
196+
{"*.example.com", "example.com", false},
197+
{"api.*", "api.example.com", true},
198+
{"*", "anything", true},
199+
{"prefix*suffix", "prefixmiddlesuffix", true},
200+
{"prefix*suffix", "prefixsuffix", true},
201+
{"prefix*suffix", "wrongmiddlesuffix", false},
202+
}
203+
204+
for _, tt := range tests {
205+
t.Run(tt.pattern+"_"+tt.value, func(t *testing.T) {
206+
p := NewPattern(tt.pattern)
207+
if got := p.Match(tt.value); got != tt.expected {
208+
t.Errorf("Pattern(%q).Match(%q) = %v, want %v", tt.pattern, tt.value, got, tt.expected)
209+
}
210+
})
211+
}
212+
}

0 commit comments

Comments
 (0)