Skip to content

Commit b94f4c9

Browse files
pcaillaudmory-bot
authored andcommitted
feat: update GetActiveRecoveryStrategies method
GitOrigin-RevId: ae8121e4488d8fc332bea1bcea704b15c2cd7987
1 parent 420f69d commit b94f4c9

12 files changed

Lines changed: 36 additions & 37 deletions

driver/registry_default_recovery.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,17 @@ func (m *RegistryDefault) RecoveryStrategies(ctx context.Context) (recoveryStrat
4141
return
4242
}
4343

44-
// GetActiveRecoveryStrategies returns the currently active recovery strategies
45-
// If no primary recovery strategy has been set, an error is returned
46-
func (m *RegistryDefault) GetActiveRecoveryStrategies(ctx context.Context) (recovery.Strategies, error) {
44+
// GetActiveRecoveryStrategies returns the currently active recovery strategies.
45+
// It returns a list of all strategies and the specific primary strategy.
46+
// If no primary recovery strategy has been set, an error is returned.
47+
func (m *RegistryDefault) GetActiveRecoveryStrategies(ctx context.Context) (active recovery.Strategies, primary recovery.Strategy, err error) {
4748
as := m.Config().SelfServiceFlowRecoveryUse(ctx)
48-
s, err := m.RecoveryStrategies(ctx).ActiveStrategies(as)
49+
s, ps, err := m.RecoveryStrategies(ctx).ActiveStrategies(as)
4950
if err != nil {
50-
return nil, errors.WithStack(herodot.ErrBadRequest.
51+
return nil, ps, errors.WithStack(herodot.ErrBadRequest.
5152
WithReasonf("You attempted recovery using %s, which is not enabled or does not exist. An administrator needs to enable this recovery method.", as))
5253
}
53-
return s, nil
54+
return s, ps, nil
5455
}
5556

5657
func (m *RegistryDefault) AllRecoveryStrategies() (recoveryStrategies recovery.Strategies) {

driver/registry_default_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,7 @@ func TestGetActiveRecoveryStrategy(t *testing.T) {
917917
config.ViperKeySelfServiceRecoveryUse: "code",
918918
})
919919

920-
_, err := reg.GetActiveRecoveryStrategies(ctx)
920+
_, _, err := reg.GetActiveRecoveryStrategies(ctx)
921921
require.Error(t, err)
922922
})
923923

@@ -931,10 +931,10 @@ func TestGetActiveRecoveryStrategy(t *testing.T) {
931931
config.ViperKeySelfServiceRecoveryUse: sID,
932932
})
933933

934-
s, err := reg.GetActiveRecoveryStrategies(ctx)
934+
s, ps, err := reg.GetActiveRecoveryStrategies(ctx)
935935
require.NoError(t, err)
936936
require.Len(t, s, 1)
937-
require.Equal(t, sID, s[0].RecoveryStrategyID())
937+
require.Equal(t, sID, ps.RecoveryStrategyID())
938938
})
939939
}
940940
})

driver/registry_default_verification.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (m *RegistryDefault) GetActiveVerificationStrategies(ctx context.Context) (
5858
as := m.Config().SelfServiceFlowVerificationUse(ctx)
5959
s, ps, err := m.VerificationStrategies(ctx).ActiveStrategies(as)
6060
if err != nil {
61-
return nil, nil, errors.WithStack(herodot.ErrBadRequest.
61+
return nil, ps, errors.WithStack(herodot.ErrBadRequest.
6262
WithReasonf("The active verification strategy %s is not enabled. Please enable it in the configuration.", as))
6363
}
6464
return s, ps, nil

selfservice/flow/recovery/error.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ func (s *ErrorHandler) WriteFlowError(
8888
trace.SpanFromContext(r.Context()).AddEvent(events.NewRecoveryFailed(r.Context(), f.ID, string(f.Type), f.Active.String(), recoveryErr))
8989

9090
if expiredError := new(flow.ExpiredError); errors.As(recoveryErr, &expiredError) {
91-
strategies, err := s.d.RecoveryStrategies(r.Context()).ActiveStrategies(f.Active.String())
91+
strategies, _, err := s.d.RecoveryStrategies(r.Context()).ActiveStrategies(f.Active.String())
9292
if err != nil {
93-
strategies, err = s.d.GetActiveRecoveryStrategies(r.Context())
93+
strategies, _, err = s.d.GetActiveRecoveryStrategies(r.Context())
9494
// Can't retry the recovery if no primary strategy has been set
9595
if err != nil {
9696
s.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)

selfservice/flow/recovery/error_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func TestHandleError(t *testing.T) {
7474

7575
newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *recovery.Flow {
7676
req := &http.Request{URL: urlx.ParseOrPanic("/")}
77-
s, err := reg.GetActiveRecoveryStrategies(context.Background())
77+
s, _, err := reg.GetActiveRecoveryStrategies(context.Background())
7878
require.NoError(t, err)
7979
f, err := recovery.NewFlow(conf, ttl, nosurfx.FakeCSRFToken, req, s, ft)
8080
require.NoError(t, err)
@@ -191,7 +191,7 @@ func TestHandleError(t *testing.T) {
191191
c, reg := pkg.NewVeryFastRegistryWithoutDB(t)
192192
require.NoError(t, c.Set(context.Background(), "selfservice.methods.code.enabled", false))
193193
require.NoError(t, c.Set(context.Background(), config.ViperKeySelfServiceRecoveryUse, "code"))
194-
_, err := reg.GetActiveRecoveryStrategies(context.Background())
194+
_, _, err := reg.GetActiveRecoveryStrategies(context.Background())
195195
recoveryFlow = newFlow(t, time.Minute, tc.t)
196196
flowError = err
197197
methodName = node.UiNodeGroup(recovery.RecoveryStrategyLink)
@@ -337,7 +337,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
337337

338338
newFlow := func(t *testing.T, ttl time.Duration, ft flow.Type) *recovery.Flow {
339339
req := &http.Request{URL: urlx.ParseOrPanic("/")}
340-
s, err := reg.GetActiveRecoveryStrategies(context.Background())
340+
s, _, err := reg.GetActiveRecoveryStrategies(context.Background())
341341
require.NoError(t, err)
342342
f, err := recovery.NewFlow(conf, ttl, nosurfx.FakeCSRFToken, req, s, ft)
343343
require.NoError(t, err)
@@ -462,7 +462,7 @@ func TestHandleError_WithContinueWith(t *testing.T) {
462462
c, reg := pkg.NewVeryFastRegistryWithoutDB(t)
463463
require.NoError(t, c.Set(context.Background(), "selfservice.methods.code.enabled", false))
464464
require.NoError(t, c.Set(context.Background(), config.ViperKeySelfServiceRecoveryUse, "code"))
465-
_, err := reg.GetActiveRecoveryStrategies(context.Background())
465+
_, _, err := reg.GetActiveRecoveryStrategies(context.Background())
466466
recoveryFlow = newFlow(t, time.Minute, tc.t)
467467
flowError = err
468468
methodName = node.UiNodeGroup(recovery.RecoveryStrategyLink)

selfservice/flow/recovery/handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (h *Handler) createNativeRecoveryFlow(w http.ResponseWriter, r *http.Reques
122122
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
123123
return
124124
}
125-
activeRecoveryStrategies, err := h.d.GetActiveRecoveryStrategies(r.Context())
125+
activeRecoveryStrategies, _, err := h.d.GetActiveRecoveryStrategies(r.Context())
126126
if err != nil {
127127
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
128128
return
@@ -190,7 +190,7 @@ func (h *Handler) createBrowserRecoveryFlow(w http.ResponseWriter, r *http.Reque
190190
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Recovery is not allowed because it was disabled.")))
191191
return
192192
}
193-
activeRecoveryStrategies, err := h.d.GetActiveRecoveryStrategies(r.Context())
193+
activeRecoveryStrategies, _, err := h.d.GetActiveRecoveryStrategies(r.Context())
194194
if err != nil {
195195
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
196196
return

selfservice/flow/recovery/strategy.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,25 @@ type (
3333
StrategyProvider interface {
3434
AllRecoveryStrategies() Strategies
3535
RecoveryStrategies(ctx context.Context) Strategies
36-
GetActiveRecoveryStrategies(ctx context.Context) (Strategies, error)
36+
GetActiveRecoveryStrategies(ctx context.Context) (active Strategies, primary Strategy, err error)
3737
}
3838
)
3939

40-
func (s Strategies) ActiveStrategies(id string) (Strategies, error) {
40+
func (s Strategies) ActiveStrategies(id string) (active Strategies, primary Strategy, err error) {
4141
ids := make([]string, len(s))
42-
activeStrategies := Strategies{}
43-
foundPrimary := false
4442
for k, ss := range s {
4543
ids[k] = ss.RecoveryStrategyID()
4644
if ss.RecoveryStrategyID() == id || !ss.IsPrimary() {
47-
activeStrategies = append(activeStrategies, ss)
45+
active = append(active, ss)
4846
if ss.IsPrimary() {
49-
foundPrimary = true
47+
primary = ss
5048
}
5149
}
5250
}
5351

54-
if !foundPrimary {
55-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("unable to find strategy for %s have %v", id, ids))
52+
if primary == nil {
53+
return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("unable to find strategy for %s have %v", id, ids))
5654
}
5755

58-
return activeStrategies, nil
56+
return active, primary, nil
5957
}

selfservice/hook/require_verified_address.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ func (e *AddressVerifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Re
106106

107107
verificationFlow.State = flow.StateEmailSent
108108
for _, strategy := range strategies {
109+
if strategy.IsPrimary() {
110+
verificationFlow.Active = sqlxx.NullString(strategy.NodeGroup())
111+
}
109112
if err := strategy.PopulateVerificationMethod(r, verificationFlow); err != nil {
110113
return err
111114
}
@@ -119,7 +122,6 @@ func (e *AddressVerifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Re
119122
WithMetaLabel(text.NewInfoNodeResendOTP()),
120123
)
121124

122-
verificationFlow.Active = sqlxx.NullString(strategy.NodeGroup())
123125
if err := e.r.VerificationFlowPersister().CreateVerificationFlow(ctx, verificationFlow); err != nil {
124126
return err
125127
}

selfservice/strategy/code/code_sender_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ func TestSender(t *testing.T) {
250250
flow: "recovery",
251251
configKey: config.ViperKeySelfServiceRecoveryNotifyUnknownRecipients,
252252
send: func(t *testing.T) {
253-
s, err := reg.RecoveryStrategies(ctx).ActiveStrategies("code")
253+
s, _, err := reg.RecoveryStrategies(ctx).ActiveStrategies("code")
254254
require.NoError(t, err)
255255
f, err := recovery.NewFlow(conf, time.Hour, "", u, s, flow.TypeBrowser)
256256
require.NoError(t, err)

selfservice/strategy/code/strategy_recovery_admin_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ func TestAdminStrategy(t *testing.T) {
6161

6262
t.Run("no panic on empty body #1384", func(t *testing.T) {
6363
ctx := context.Background()
64-
s, err := reg.RecoveryStrategies(ctx).ActiveStrategies("code")
64+
s, ps, err := reg.RecoveryStrategies(ctx).ActiveStrategies("code")
6565
require.NoError(t, err)
6666
require.Len(t, s, 1)
6767
w := httptest.NewRecorder()
6868
r := &http.Request{URL: new(url.URL)}
6969
f, err := recovery.NewFlow(reg.Config(), time.Minute, "", r, s, flow.TypeBrowser)
7070
require.NoError(t, err)
7171
require.NotPanics(t, func() {
72-
require.Error(t, s[0].(*Strategy).HandleRecoveryError(w, r, f, nil, errors.New("test")))
72+
require.Error(t, ps.(*Strategy).HandleRecoveryError(w, r, f, nil, errors.New("test")))
7373
})
7474
})
7575

0 commit comments

Comments
 (0)