-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathmain.go
More file actions
300 lines (263 loc) · 7.46 KB
/
main.go
File metadata and controls
300 lines (263 loc) · 7.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
package main
import (
	"bytes"
	_ "embed"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"log"
	"os"
	"slices"
	"strconv"
	"strings"
	"text/template"

	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/coderd/rbac"
	"github.com/coder/coder/v2/coderd/rbac/policy"
	utilstrings "github.com/coder/coder/v2/coderd/util/strings"
	"github.com/coder/coder/v2/codersdk"
)
// rbacObjectTemplate is the template rendered for `rbac object`; its output
// is passed through format.Source. Embedded at compile time via //go:embed.
//go:embed rbacobject.gotmpl
var rbacObjectTemplate string

// codersdkTemplate is the template rendered for `rbac codersdk`; its output
// is passed through format.Source.
//go:embed codersdk.gotmpl
var codersdkTemplate string

// typescriptTemplate is the template rendered for `rbac typescript`; its
// output is emitted without any formatting pass.
//go:embed typescript.tstmpl
var typescriptTemplate string

// scopenamesTemplate is the template rendered for `rbac scopenames`; its
// output is passed through format.Source.
//go:embed scopenames.gotmpl
var scopenamesTemplate string

// countriesTemplate is the template rendered for the `countries` type, with
// codersdk.Countries as its data.
//go:embed countries.tstmpl
var countriesTemplate string
// usage prints the command-line synopsis and every supported generation
// target to stdout.
func usage() {
	_, _ = fmt.Println("Usage: typegen <type> [template]")
	_, _ = fmt.Println("Types:")
	// Keep this list in sync with the templates accepted by generateRBAC.
	_, _ = fmt.Println(" rbac <object|codersdk|typescript|scopenames> - Generate RBAC related files")
	_, _ = fmt.Println(" countries - Generate countries TypeScript")
}
// main generates a source file for the requested type (and, for "rbac", the
// requested template) and writes it to stdout. Generating the RBAC helpers
// from the policy definitions keeps functions such as "AllResources" in sync.
//
// Exit status is 1 when required arguments are missing and 2 for an unknown
// type. Generation or write failures terminate via log.Fatalf.
func main() {
	flag.Parse()
	args := flag.Args()
	if len(args) < 1 {
		usage()
		os.Exit(1)
	}

	var (
		out []byte
		err error
	)
	// A single generator serves both the backend and the SDK; the template
	// argument selects which output format is rendered.
	switch strings.ToLower(args[0]) {
	case "rbac":
		if len(args) < 2 {
			usage()
			os.Exit(1)
		}
		out, err = generateRBAC(args[1])
	case "countries":
		out, err = generateCountries()
	default:
		_, _ = fmt.Fprintf(os.Stderr, "%q is not a valid type\n", args[0])
		usage()
		os.Exit(2)
	}
	if err != nil {
		log.Fatalf("Generate source: %s", err.Error())
	}

	// A failed or short write must not exit 0; the output is presumably
	// redirected into a generated file that would otherwise be truncated.
	if _, err := os.Stdout.Write(out); err != nil {
		log.Fatalf("Write output: %s", err.Error())
	}
}
// generateRBAC renders the RBAC template named by tmpl (case-insensitive:
// "codersdk", "object", "typescript" or "scopenames") via
// generateRbacObjects. Go outputs are run through format.Source; TypeScript
// output is returned as rendered. An unrecognized tmpl yields an error.
func generateRBAC(tmpl string) ([]byte, error) {
	var (
		source   string
		goFormat = true
	)
	switch strings.ToLower(tmpl) {
	case "codersdk":
		source = codersdkTemplate
	case "object":
		source = rbacObjectTemplate
	case "scopenames":
		source = scopenamesTemplate
	case "typescript":
		// There is no TypeScript formatter here; emit the output verbatim.
		source, goFormat = typescriptTemplate, false
	default:
		return nil, xerrors.Errorf("%q is not a valid RBAC template target", tmpl)
	}

	rendered, err := generateRbacObjects(source)
	if err != nil {
		return nil, err
	}
	if !goFormat {
		return rendered, nil
	}
	return format.Source(rendered)
}
// generateCountries renders countriesTemplate with codersdk.Countries as the
// template data and returns the unformatted result.
func generateCountries() ([]byte, error) {
	var rendered bytes.Buffer

	countries, err := template.New("countries.tstmpl").Parse(countriesTemplate)
	if err != nil {
		return nil, xerrors.Errorf("parse template: %w", err)
	}
	if err := countries.Execute(&rendered, codersdk.Countries); err != nil {
		return nil, xerrors.Errorf("execute template: %w", err)
	}
	return rendered.Bytes(), nil
}
// pascalCaseName splits name on underscores, passes each part through
// utilstrings.Capitalize, and concatenates the results with no separator.
func pascalCaseName[T ~string](name T) string {
	var sb strings.Builder
	for _, part := range strings.Split(string(name), "_") {
		sb.WriteString(utilstrings.Capitalize(part))
	}
	return sb.String()
}
// Definition pairs a policy.PermissionDefinition with the key (Type) it is
// stored under in policy.RBACPermissions. A slice of these, sorted by Type,
// is the data passed to the RBAC templates.
type Definition struct {
	policy.PermissionDefinition
	Type string
}

// FunctionName returns the definition's explicit Name when it is non-empty,
// and its Type otherwise.
func (p Definition) FunctionName() string {
	switch {
	case p.Name == "":
		return p.Type
	default:
		return p.Name
	}
}
// fileActions maps each Action enum value to the identifier of the constant
// that declares it (e.g. "read" -> "ActionRead"). Constant identifiers are
// not available at runtime, so they are recovered from the parsed file; this
// keeps the generated code referring to the same enum names as the source.
//
// Every value spec with an explicit `Action` type and one string literal per
// name is collected, including multi-name specs such as
// `A, B Action = "a", "b"`. Unrelated specs are skipped individually, so they
// do not hide Action constants declared after them in the same block.
func fileActions(file *ast.File) map[string]string {
	actions := make(map[string]string)
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range genDecl.Specs {
			vSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			typeIdent, ok := vSpec.Type.(*ast.Ident)
			if !ok || typeIdent.Name != "Action" {
				continue
			}
			// Only explicit one-value-per-name declarations carry a literal
			// we can pair with each identifier.
			if len(vSpec.Names) != len(vSpec.Values) {
				continue
			}
			for i, name := range vSpec.Names {
				lit, ok := vSpec.Values[i].(*ast.BasicLit)
				if !ok || lit.Kind != token.STRING {
					continue
				}
				// Unquote handles both interpreted ("...") and raw (`...`)
				// literals, including escape sequences.
				value, err := strconv.Unquote(lit.Value)
				if err != nil {
					continue
				}
				actions[value] = name.Name
			}
		}
	}
	return actions
}
// ActionDetails describes one Action enum constant found in policy.go: Enum
// is the constant's identifier and Value is its string value. A slice of
// these, sorted by Enum, is exposed to templates through "actionsList".
type ActionDetails struct {
	Enum, Value string
}
// generateRbacObjects renders templateSource with the RBAC permission
// definitions in policy.RBACPermissions as its data. The Action enum constant
// names are recovered from the AST of policy.go (see fileActions) and exposed
// to the template alongside the definitions.
//
// policy.go is opened through a path relative to the working directory, so
// the generator presumably has to be run from the repository root.
//
// The template data is a []Definition sorted by Type. Template functions:
//   - capitalize, pascalCaseName, concat: string helpers.
//   - actionsList: every Action enum (identifier and value), sorted by identifier.
//   - actionsOf: the action values of one Definition, sorted.
//   - allCaseList: the builtin scope names (as ScopeName("...") literals)
//     followed by one Scope<Resource><Action> constant per definition action,
//     joined with ",\n\t\t" for use in a `case` clause (no trailing comma).
//   - actionEnum: the enum identifier for an action value.
//
// Every action value passed to actionEnum that has no matching constant is
// recorded, and all of them are returned together as one joined error after
// the template has executed.
func generateRbacObjects(templateSource string) ([]byte, error) {
	// Parse the policy.go file for the action enums.
	f, err := parser.ParseFile(token.NewFileSet(), "./coderd/rbac/policy/policy.go", nil, parser.ParseComments)
	if err != nil {
		return nil, xerrors.Errorf("parsing policy.go: %w", err)
	}
	actionMap := fileActions(f)

	actionList := make([]ActionDetails, 0, len(actionMap))
	for value, enum := range actionMap {
		actionList = append(actionList, ActionDetails{Enum: enum, Value: value})
	}
	// Map iteration order is random; sort so generated output is stable.
	slices.SortFunc(actionList, func(a, b ActionDetails) int {
		return strings.Compare(a.Enum, b.Enum)
	})

	// sortedActions returns the action values of d in ascending order. The
	// actions are stored in a map, so sorting is needed for stable output.
	sortedActions := func(d Definition) []string {
		acts := make([]string, 0, len(d.Actions))
		for a := range d.Actions {
			acts = append(acts, string(a))
		}
		slices.Sort(acts)
		return acts
	}

	// actionEnum appends here instead of aborting on the first unknown
	// action, so every missing constant is reported in a single run.
	var errorList []error
	tpl, err := template.New("object.gotmpl").Funcs(template.FuncMap{
		"capitalize":     utilstrings.Capitalize,
		"pascalCaseName": pascalCaseName[string],
		"actionsList": func() []ActionDetails {
			return actionList
		},
		"actionsOf": sortedActions,
		"allCaseList": func(defs []Definition) string {
			var names []string
			// Builtins first, sourced from the rbac package to avoid drift.
			// Typed string literals avoid relying on constant identifiers.
			for _, n := range rbac.BuiltinScopeNames() {
				names = append(names, fmt.Sprintf("ScopeName(%q)", string(n)))
			}
			for _, d := range defs {
				res := pascalCaseName(d.Type)
				for _, a := range sortedActions(d) {
					names = append(names, "Scope"+res+pascalCaseName(a))
				}
			}
			return strings.Join(names, ",\n\t\t")
		},
		"actionEnum": func(action policy.Action) string {
			v, ok := actionMap[string(action)]
			if !ok {
				errorList = append(errorList, xerrors.Errorf("action value %q does not have a matching enum constant", action))
			}
			return v
		},
		"concat": func(strs ...string) string { return strings.Join(strs, "") },
	}).Parse(templateSource)
	if err != nil {
		return nil, xerrors.Errorf("parse template: %w", err)
	}

	// Convert to a sorted list for autogen consistency.
	list := make([]Definition, 0, len(policy.RBACPermissions))
	for t, v := range policy.RBACPermissions {
		list = append(list, Definition{
			PermissionDefinition: v,
			Type:                 t,
		})
	}
	slices.SortFunc(list, func(a, b Definition) int {
		return strings.Compare(a.Type, b.Type)
	})

	var out bytes.Buffer
	if err := tpl.Execute(&out, list); err != nil {
		return nil, xerrors.Errorf("execute template: %w", err)
	}
	if len(errorList) > 0 {
		return nil, errors.Join(errorList...)
	}
	return out.Bytes(), nil
}