Skip to content

Commit 3dee8f5

Browse files
pcaillaudmory-bot
authored andcommitted
feat: add captcha strategy for recovery flow
GitOrigin-RevId: 8bf87d0de35855b50cd0c995329b4e40520a4c2b
1 parent 071ad54 commit 3dee8f5

18 files changed

Lines changed: 78 additions & 53 deletions

driver/registry_default_recovery.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ func (m *RegistryDefault) RecoveryStrategies(ctx context.Context) (recoveryStrat
4141
return
4242
}
4343

44-
// GetActiveRecoveryStrategy returns the currently active recovery strategy
45-
// If no recovery strategy has been set, an error is returned
46-
func (m *RegistryDefault) GetActiveRecoveryStrategy(ctx context.Context) (recovery.Strategy, error) {
44+
// GetActiveRecoveryStrategies returns the currently active recovery strategies
45+
// If no primary recovery strategy has been set, an error is returned
46+
func (m *RegistryDefault) GetActiveRecoveryStrategies(ctx context.Context) (recovery.Strategies, error) {
4747
as := m.Config().SelfServiceFlowRecoveryUse(ctx)
48-
s, err := m.RecoveryStrategies(ctx).Strategy(as)
48+
s, err := m.RecoveryStrategies(ctx).ActiveStrategies(as)
4949
if err != nil {
5050
return nil, errors.WithStack(herodot.ErrBadRequest.
5151
WithReasonf("You attempted recovery using %s, which is not enabled or does not exist. An administrator needs to enable this recovery method.", as))

driver/registry_default_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,11 +917,11 @@ func TestGetActiveRecoveryStrategy(t *testing.T) {
917917
config.ViperKeySelfServiceRecoveryUse: "code",
918918
})
919919

920-
_, err := reg.GetActiveRecoveryStrategy(ctx)
920+
_, err := reg.GetActiveRecoveryStrategies(ctx)
921921
require.Error(t, err)
922922
})
923923

924-
t.Run("returns active strategy", func(t *testing.T) {
924+
t.Run("returns active strategies", func(t *testing.T) {
925925
for _, sID := range []string{
926926
"code", "link",
927927
} {
@@ -931,9 +931,10 @@ func TestGetActiveRecoveryStrategy(t *testing.T) {
931931
config.ViperKeySelfServiceRecoveryUse: sID,
932932
})
933933

934-
s, err := reg.GetActiveRecoveryStrategy(ctx)
934+
s, err := reg.GetActiveRecoveryStrategies(ctx)
935935
require.NoError(t, err)
936-
require.Equal(t, sID, s.RecoveryStrategyID())
936+
require.Len(t, s, 1)
937+
require.Equal(t, sID, s[0].RecoveryStrategyID())
937938
})
938939
}
939940
})

pkg/testhelpers/selfservice_verification.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,10 @@ func SubmitRecoveryForm(
227227
return b
228228
}
229229

230-
func PersistNewRecoveryFlow(t *testing.T, strategy recovery.Strategy, conf *config.Config, reg *driver.RegistryDefault) *recovery.Flow {
230+
func PersistNewRecoveryFlow(t *testing.T, strategies recovery.Strategies, conf *config.Config, reg *driver.RegistryDefault) *recovery.Flow {
231231
t.Helper()
232232
req := NewTestHTTPRequest(t, "GET", conf.SelfPublicURL(context.Background()).String()+"/test", nil)
233-
f, err := recovery.NewFlow(conf, conf.SelfServiceFlowRecoveryRequestLifespan(context.Background()), reg.GenerateCSRFToken(req), req, strategy, flow.TypeBrowser)
233+
f, err := recovery.NewFlow(conf, conf.SelfServiceFlowRecoveryRequestLifespan(context.Background()), reg.GenerateCSRFToken(req), req, strategies, flow.TypeBrowser)
234234
require.NoError(t, err, "Expected no error when creating a new recovery flow: %s", err)
235235

236236
err = reg.RecoveryFlowPersister().CreateRecoveryFlow(context.Background(), f)

selfservice/flow/recovery/error.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,17 @@ func (s *ErrorHandler) WriteFlowError(
8888
trace.SpanFromContext(r.Context()).AddEvent(events.NewRecoveryFailed(r.Context(), f.ID, string(f.Type), f.Active.String(), recoveryErr))
8989

9090
if expiredError := new(flow.ExpiredError); errors.As(recoveryErr, &expiredError) {
91-
strategy, err := s.d.RecoveryStrategies(r.Context()).Strategy(f.Active.String())
91+
strategies, err := s.d.RecoveryStrategies(r.Context()).ActiveStrategies(f.Active.String())
9292
if err != nil {
93-
strategy, err = s.d.GetActiveRecoveryStrategy(r.Context())
94-
// Can't retry the recovery if no strategy has been set
93+
strategies, err = s.d.GetActiveRecoveryStrategies(r.Context())
94+
// Can't retry the recovery if no primary strategy has been set
9595
if err != nil {
9696
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
9797
return
9898
}
9999
}
100100
// create new flow because the old one is not valid
101-
newFlow, err := FromOldFlow(s.d.Config(), s.d.Config().SelfServiceFlowRecoveryRequestLifespan(r.Context()), s.d.GenerateCSRFToken(r), r, strategy, *f)
101+
newFlow, err := FromOldFlow(s.d.Config(), s.d.Config().SelfServiceFlowRecoveryRequestLifespan(r.Context()), s.d.GenerateCSRFToken(r), r, strategies, *f)
102102
if err != nil {
103103
// failed to create a new session and redirect to it, handle that error as a new one
104104
s.WriteFlowError(w, r, f, group, err)

selfservice/flow/recovery/error_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func TestHandleError(t *testing.T) {
7474

7575
newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *recovery.Flow {
7676
req := &http.Request{URL: urlx.ParseOrPanic("/")}
77-
s, err := reg.GetActiveRecoveryStrategy(context.Background())
77+
s, err := reg.GetActiveRecoveryStrategies(context.Background())
7878
require.NoError(t, err)
7979
f, err := recovery.NewFlow(conf, ttl, nosurfx.FakeCSRFToken, req, s, ft)
8080
require.NoError(t, err)
@@ -191,7 +191,7 @@ func TestHandleError(t *testing.T) {
191191
c, reg := pkg.NewVeryFastRegistryWithoutDB(t)
192192
require.NoError(t, c.Set(context.Background(), "selfservice.methods.code.enabled", false))
193193
require.NoError(t, c.Set(context.Background(), config.ViperKeySelfServiceRecoveryUse, "code"))
194-
_, err := reg.GetActiveRecoveryStrategy(context.Background())
194+
_, err := reg.GetActiveRecoveryStrategies(context.Background())
195195
recoveryFlow = newFlow(t, time.Minute, tc.t)
196196
flowError = err
197197
methodName = node.UiNodeGroup(recovery.RecoveryStrategyLink)
@@ -337,7 +337,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
337337

338338
newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *recovery.Flow {
339339
req := &http.Request{URL: urlx.ParseOrPanic("/")}
340-
s, err := reg.GetActiveRecoveryStrategy(context.Background())
340+
s, err := reg.GetActiveRecoveryStrategies(context.Background())
341341
require.NoError(t, err)
342342
f, err := recovery.NewFlow(conf, ttl, nosurfx.FakeCSRFToken, req, s, ft)
343343
require.NoError(t, err)
@@ -462,7 +462,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
462462
c, reg := pkg.NewVeryFastRegistryWithoutDB(t)
463463
require.NoError(t, c.Set(context.Background(), "selfservice.methods.code.enabled", false))
464464
require.NoError(t, c.Set(context.Background(), config.ViperKeySelfServiceRecoveryUse, "code"))
465-
_, err := reg.GetActiveRecoveryStrategy(context.Background())
465+
_, err := reg.GetActiveRecoveryStrategies(context.Background())
466466
recoveryFlow = newFlow(t, time.Minute, tc.t)
467467
flowError = err
468468
methodName = node.UiNodeGroup(recovery.RecoveryStrategyLink)

selfservice/flow/recovery/flow.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ type Flow struct {
116116

117117
var _ flow.Flow = (*Flow)(nil)
118118

119-
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, ft flow.Type) (*Flow, error) {
119+
func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategies Strategies, ft flow.Type) (*Flow, error) {
120120
now := time.Now().UTC()
121121
id := x.NewUUID()
122122

@@ -153,8 +153,10 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
153153
Type: ft,
154154
}
155155

156-
if strategy != nil {
157-
f.Active = sqlxx.NullString(strategy.NodeGroup())
156+
for _, strategy := range strategies {
157+
if strategy.IsPrimary() {
158+
f.Active = sqlxx.NullString(strategy.NodeGroup())
159+
}
158160
if err := strategy.PopulateRecoveryMethod(r, f); err != nil {
159161
return nil, err
160162
}
@@ -163,13 +165,13 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
163165
return f, nil
164166
}
165167

166-
func FromOldFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategy Strategy, of Flow) (*Flow, error) {
168+
func FromOldFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Request, strategies Strategies, of Flow) (*Flow, error) {
167169
f := of.Type
168170
// Using the same flow in the recovery/verification context can lead to using API flow in a verification/recovery email
169171
if of.Type == flow.TypeAPI && of.Active.String() == string(RecoveryStrategyLink) {
170172
f = flow.TypeBrowser
171173
}
172-
nf, err := NewFlow(conf, exp, csrf, r, strategy, f)
174+
nf, err := NewFlow(conf, exp, csrf, r, strategies, f)
173175
if err != nil {
174176
return nil, err
175177
}

selfservice/flow/recovery/flow_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func TestFromOldFlow(t *testing.T) {
9898
flow.TypeBrowser,
9999
} {
100100
t.Run(fmt.Sprintf("case=original flow is %s", ft), func(t *testing.T) {
101-
f, err := recovery.NewFlow(conf, 0, "csrf", &r, code.NewStrategy(reg), ft)
101+
f, err := recovery.NewFlow(conf, 0, "csrf", &r, recovery.Strategies{code.NewStrategy(reg)}, ft)
102102
require.NoError(t, err)
103103
nF, err := recovery.FromOldFlow(conf, time.Duration(time.Hour), f.CSRFToken, &r, nil, *f)
104104
require.NoError(t, err)
@@ -113,7 +113,7 @@ func TestFromOldFlow(t *testing.T) {
113113
flow.TypeBrowser,
114114
} {
115115
t.Run(fmt.Sprintf("case=original flow is %s", ft), func(t *testing.T) {
116-
f, err := recovery.NewFlow(conf, 0, "csrf", &r, link.NewStrategy(reg), ft)
116+
f, err := recovery.NewFlow(conf, 0, "csrf", &r, recovery.Strategies{link.NewStrategy(reg)}, ft)
117117
require.NoError(t, err)
118118
nF, err := recovery.FromOldFlow(conf, time.Duration(time.Hour), f.CSRFToken, &r, nil, *f)
119119
require.NoError(t, err)

selfservice/flow/recovery/handler.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,13 @@ func (h *Handler) createNativeRecoveryFlow(w http.ResponseWriter, r *http.Reques
122122
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
123123
return
124124
}
125-
activeRecoveryStrategy, err := h.d.GetActiveRecoveryStrategy(r.Context())
125+
activeRecoveryStrategies, err := h.d.GetActiveRecoveryStrategies(r.Context())
126126
if err != nil {
127127
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
128128
return
129129
}
130130

131-
f, err := NewFlow(h.d.Config(), h.d.Config().SelfServiceFlowRecoveryRequestLifespan(r.Context()), h.d.GenerateCSRFToken(r), r, activeRecoveryStrategy, flow.TypeAPI)
131+
f, err := NewFlow(h.d.Config(), h.d.Config().SelfServiceFlowRecoveryRequestLifespan(r.Context()), h.d.GenerateCSRFToken(r), r, activeRecoveryStrategies, flow.TypeAPI)
132132
if err != nil {
133133
h.d.Writer().WriteError(w, r, err)
134134
return
@@ -190,13 +190,13 @@ func (h *Handler) createBrowserRecoveryFlow(w http.ResponseWriter, r *http.Reque
190190
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
191191
return
192192
}
193-
activeRecoveryStrategy, err := h.d.GetActiveRecoveryStrategy(r.Context())
193+
activeRecoveryStrategies, err := h.d.GetActiveRecoveryStrategies(r.Context())
194194
if err != nil {
195195
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
196196
return
197197
}
198198

199-
f, err := NewFlow(h.d.Config(), h.d.Config().SelfServiceFlowRecoveryRequestLifespan(r.Context()), h.d.GenerateCSRFToken(r), r, activeRecoveryStrategy, flow.TypeBrowser)
199+
f, err := NewFlow(h.d.Config(), h.d.Config().SelfServiceFlowRecoveryRequestLifespan(r.Context()), h.d.GenerateCSRFToken(r), r, activeRecoveryStrategies, flow.TypeBrowser)
200200
if err != nil {
201201
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
202202
return

selfservice/flow/recovery/hook_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestRecoveryExecutor(t *testing.T) {
3434
conf, reg := pkg.NewFastRegistryWithMocks(t,
3535
configx.WithValues(testhelpers.DefaultIdentitySchemaConfig("file://./stub/identity.schema.json")),
3636
)
37-
s := code.NewStrategy(reg)
37+
s := recovery.Strategies{code.NewStrategy(reg)}
3838

3939
newServer := func(t *testing.T, i *identity.Identity, ft flow.Type) *httptest.Server {
4040
router := http.NewServeMux()

selfservice/flow/recovery/strategy.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ const (
2424
type (
2525
Strategy interface {
2626
RecoveryStrategyID() string
27+
IsPrimary() bool
2728
NodeGroup() node.UiNodeGroup
2829
PopulateRecoveryMethod(*http.Request, *Flow) error
2930
Recover(w http.ResponseWriter, r *http.Request, f *Flow) (err error)
@@ -32,18 +33,27 @@ type (
3233
StrategyProvider interface {
3334
AllRecoveryStrategies() Strategies
3435
RecoveryStrategies(ctx context.Context) Strategies
35-
GetActiveRecoveryStrategy(ctx context.Context) (Strategy, error)
36+
GetActiveRecoveryStrategies(ctx context.Context) (Strategies, error)
3637
}
3738
)
3839

39-
func (s Strategies) Strategy(id string) (Strategy, error) {
40+
func (s Strategies) ActiveStrategies(id string) (Strategies, error) {
4041
ids := make([]string, len(s))
42+
activeStrategies := Strategies{}
43+
foundPrimary := false
4144
for k, ss := range s {
4245
ids[k] = ss.RecoveryStrategyID()
43-
if ss.RecoveryStrategyID() == id {
44-
return ss, nil
46+
if ss.RecoveryStrategyID() == id || !ss.IsPrimary() {
47+
activeStrategies = append(activeStrategies, ss)
48+
if ss.IsPrimary() {
49+
foundPrimary = true
50+
}
4551
}
4652
}
4753

48-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("unable to find strategy for %s have %v", id, ids))
54+
if !foundPrimary {
55+
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("unable to find strategy for %s have %v", id, ids))
56+
}
57+
58+
return activeStrategies, nil
4959
}

0 commit comments

Comments
 (0)