Skip to content

Commit caeca10

Browse files
authored
chore: refactor license validation (coder#20411)
1 parent 823b14a commit caeca10

3 files changed

Lines changed: 65 additions & 3 deletions

File tree

enterprise/coderd/coderdenttest/coderdenttest.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ type LicenseOptions struct {
186186
// past.
187187
IssuedAt time.Time
188188
Features license.Features
189+
190+
AllowEmpty bool
189191
}
190192

191193
func (opts *LicenseOptions) WithIssuedAt(now time.Time) *LicenseOptions {
@@ -276,10 +278,10 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
276278
issuedAt = time.Now().Add(-time.Minute)
277279
}
278280

279-
if options.AccountType == "" {
281+
if !options.AllowEmpty && options.AccountType == "" {
280282
options.AccountType = license.AccountTypeSalesforce
281283
}
282-
if options.AccountID == "" {
284+
if !options.AllowEmpty && options.AccountID == "" {
283285
options.AccountID = "test-account-id"
284286
}
285287

enterprise/coderd/license/license.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,8 @@ var (
612612
ErrMissingLicenseExpires = xerrors.New("license has invalid or missing license_expires claim")
613613
ErrMissingExp = xerrors.New("license has invalid or missing exp (expires at) claim")
614614
ErrMultipleIssues = xerrors.New("license has multiple issues; contact support")
615+
ErrMissingAccountType = xerrors.New("license must contain valid account type")
616+
ErrMissingAccountID = xerrors.New("license must contain valid account ID")
615617
)
616618

617619
type Features map[codersdk.FeatureName]int64
@@ -696,12 +698,20 @@ func validateClaims(tok *jwt.Token) (*Claims, error) {
696698
if claims.NotBefore == nil {
697699
return nil, ErrMissingNotBefore
698700
}
699-
if claims.LicenseExpires == nil {
701+
702+
yearsHardLimit := time.Now().Add(5 /* years */ * 365 * 24 * time.Hour)
703+
if claims.LicenseExpires == nil || claims.LicenseExpires.Time.After(yearsHardLimit) {
700704
return nil, ErrMissingLicenseExpires
701705
}
702706
if claims.ExpiresAt == nil {
703707
return nil, ErrMissingExp
704708
}
709+
if claims.AccountType == "" {
710+
return nil, ErrMissingAccountType
711+
}
712+
if claims.AccountID == "" {
713+
return nil, ErrMissingAccountID
714+
}
705715
return claims, nil
706716
}
707717
return nil, xerrors.New("unable to parse Claims")

enterprise/coderd/licenses_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,56 @@ func TestPostLicense(t *testing.T) {
5454
require.Contains(t, errResp.Message, "License cannot be used on this deployment!")
5555
})
5656

57+
t.Run("InvalidAccountID", func(t *testing.T) {
58+
t.Parallel()
59+
// The generated deployment will start out with a different deployment ID.
60+
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
61+
license := coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
62+
AllowEmpty: true,
63+
AccountID: "",
64+
})
65+
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
66+
License: license,
67+
})
68+
errResp := &codersdk.Error{}
69+
require.ErrorAs(t, err, &errResp)
70+
require.Equal(t, http.StatusBadRequest, errResp.StatusCode())
71+
require.Contains(t, errResp.Message, "Invalid license")
72+
})
73+
74+
t.Run("InvalidAccountType", func(t *testing.T) {
75+
t.Parallel()
76+
// The generated deployment will start out with a different deployment ID.
77+
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
78+
license := coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
79+
AllowEmpty: true,
80+
AccountType: "",
81+
})
82+
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
83+
License: license,
84+
})
85+
errResp := &codersdk.Error{}
86+
require.ErrorAs(t, err, &errResp)
87+
require.Equal(t, http.StatusBadRequest, errResp.StatusCode())
88+
require.Contains(t, errResp.Message, "Invalid license")
89+
})
90+
91+
t.Run("InvalidLicenseExpires", func(t *testing.T) {
92+
t.Parallel()
93+
// The generated deployment will start out with a different deployment ID.
94+
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
95+
license := coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
96+
GraceAt: time.Unix(99999999999, 0),
97+
})
98+
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
99+
License: license,
100+
})
101+
errResp := &codersdk.Error{}
102+
require.ErrorAs(t, err, &errResp)
103+
require.Equal(t, http.StatusBadRequest, errResp.StatusCode())
104+
require.Contains(t, errResp.Message, "Invalid license")
105+
})
106+
57107
t.Run("Unauthorized", func(t *testing.T) {
58108
t.Parallel()
59109
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})

0 commit comments

Comments
 (0)