-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerator.go
396 lines (350 loc) · 12 KB
/
generator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 werbenhu
// SPDX-FileContributor: werbenhu
package digo
import (
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"os"
	"path/filepath"
	"sort"
	"strings"
)
// newIdent wraps name in a fresh *ast.Ident that carries no position information.
func newIdent(name string) *ast.Ident {
	ident := new(ast.Ident)
	ident.Name = name
	return ident
}
// newSelectorExpr turns a dotted selector string such as "digo.Provide" into
// an *ast.SelectorExpr. X is the text before the first dot and Sel is the text
// between the first and second dot; anything after a second dot is ignored.
// The selector must contain at least one dot, otherwise this panics with an
// index-out-of-range error.
func newSelectorExpr(selector string) *ast.SelectorExpr {
	fields := strings.SplitN(selector, ".", 3)
	pkg, name := fields[0], fields[1]
	return &ast.SelectorExpr{X: newIdent(pkg), Sel: newIdent(name)}
}
// newStarExpr builds the pointer type expression *pkg.Name from a dotted
// selector string such as "pkg.Name". The text before the first dot becomes
// the package identifier and the text between the first and second dot the
// selected name; a selector without a dot panics with an index-out-of-range error.
func newStarExpr(selector string) *ast.StarExpr {
	fields := strings.SplitN(selector, ".", 3)
	target := &ast.SelectorExpr{X: newIdent(fields[0]), Sel: newIdent(fields[1])}
	return &ast.StarExpr{X: target}
}
// newCommentGroup builds an *ast.CommentGroup with one *ast.Comment per entry
// of texts, in order. Each text is used verbatim, so callers must include the
// comment markers (e.g. "// ..."). The List is non-nil even when texts is empty.
func newCommentGroup(texts []string) *ast.CommentGroup {
	group := &ast.CommentGroup{List: make([]*ast.Comment, 0, len(texts))}
	for _, line := range texts {
		group.List = append(group.List, &ast.Comment{Text: line})
	}
	return group
}
// newCallExpr builds the call expression fn(args...). The args slice is stored
// as given (not copied), so later changes to it are visible in the node.
func newCallExpr(fn ast.Expr, args []ast.Expr) *ast.CallExpr {
	call := new(ast.CallExpr)
	call.Fun, call.Args = fn, args
	return call
}
// newExprs returns a freshly allocated slice holding exprs in order, so the
// result never aliases the caller's variadic slice. The result is non-nil even
// when no expressions are passed.
func newExprs(exprs ...ast.Expr) []ast.Expr {
	return append(make([]ast.Expr, 0, len(exprs)), exprs...)
}
// newBasicLit creates a string-kind *ast.BasicLit whose Value is val rendered
// as a Go double-quoted literal.
//
// Quotes, backslashes and non-printable characters in val are escaped (the %q
// verb follows strconv.Quote rules), so the generated source stays valid even
// when a provider ID, group ID or import path contains such characters. Plain
// values like "github.com/werbenhu/digo" render exactly as "\"" + val + "\"".
func newBasicLit(val string) *ast.BasicLit {
	return &ast.BasicLit{
		Kind:  token.STRING,
		Value: fmt.Sprintf("%q", val),
	}
}
// newErrCheckStmt creates a new error check statement for the if condition.
func newErrCheckStmt() ast.Stmt {
return &ast.IfStmt{
Cond: &ast.BinaryExpr{
X: newIdent("err"),
Op: token.NEQ,
Y: newIdent("nil"),
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{
X: newCallExpr(newIdent("panic"), []ast.Expr{
newIdent("err"),
}),
},
},
},
}
}
// newImportSpec builds an import spec for path. A non-empty alias becomes the
// spec's local name (`import alias "path"`); an empty alias leaves Name nil,
// so no alias is emitted.
func newImportSpec(path, alias string) *ast.ImportSpec {
	var local *ast.Ident
	if alias != "" {
		local = &ast.Ident{Name: alias}
	}
	return &ast.ImportSpec{Name: local, Path: newBasicLit(path)}
}
// objName derives a temporary variable name from prefix by turning every '.'
// and '/' into '_' and appending "_obj".
// For example, objName("a.b/c") returns "a_b_c_obj".
func objName(prefix string) string {
	return strings.NewReplacer(".", "_", "/", "_").Replace(prefix) + "_obj"
}
// Generator is a code generator for dependency injection.
// It collects generated declarations for a single DiPackage and writes them,
// via output, to GeneratedFileName inside Package.Folder.
type Generator struct {
	// Package is the analyzed package whose provider and group functions are
	// turned into registration code.
	Package *DiPackage
	// CalledInitFuncs holds the calls placed in the generated init() function:
	// one per generated provider function and per generated group function.
	CalledInitFuncs []ast.Stmt
	// Fset is the FileSet passed to ast.SortImports and format.Node.
	Fset *token.FileSet
	// Decls holds the generated declarations; genAllAstDecls prepends the
	// import declaration to it.
	Decls []ast.Decl
	// ImportSpecs holds the packages to import, keyed by pkg+"_"+alias
	// (see addImport) so each pair is imported only once.
	ImportSpecs map[string]ast.Spec
	// ManagerPackage is the import path of the DI runtime; Do always imports it.
	ManagerPackage string
	// RegisterFunction is the qualified "pkg.Func" called to register a
	// provider's object under its provider ID.
	RegisterFunction string
	// ProvideFunction is the qualified "pkg.Func" called to fetch an object by
	// provider ID; it is expected to return (object, error).
	ProvideFunction string
	// GroupFunction is the qualified "pkg.Func" called with (groupID, member)
	// to add a member to a group.
	GroupFunction string
	// GeneratedFileName is the name of the file written into Package.Folder.
	GeneratedFileName string
}
// NewGenerator returns a Generator for pkg with empty accumulators and the
// default runtime settings. Generated code imports github.com/werbenhu/digo
// and calls it through the "digo" package name. Output is written to
// "digo.generated.go" in the package folder.
func NewGenerator(pkg *DiPackage) *Generator {
	g := &Generator{Package: pkg, Fset: token.NewFileSet()}
	g.CalledInitFuncs = make([]ast.Stmt, 0)
	g.Decls = make([]ast.Decl, 0)
	g.ImportSpecs = make(map[string]ast.Spec)

	g.ManagerPackage = "github.com/werbenhu/digo"
	g.RegisterFunction = "digo.RegisterSingleton"
	g.ProvideFunction = "digo.Provide"
	g.GroupFunction = "digo.RegisterMember"
	g.GeneratedFileName = "digo.generated.go"
	return g
}
// defineInjectStmts turns one injector annotation into the statements that
// fetch the injected value inside a generated function:
//
//	<tmp>, err := <ProvideFunction>("<ProviderId>")
//	if err != nil {
//		panic(err)
//	}
//	<Param> := <tmp>.(<Typ>)
//
// where <tmp> is objName(inject.Param). When inject.Pkg is set, that package
// (with inject.Alias) is recorded so it appears in the generated imports.
func (g *Generator) defineInjectStmts(inject *Injector) []ast.Stmt {
	if inject.Pkg != "" {
		g.addImport(inject.Pkg, inject.Alias)
	}

	tmp := objName(inject.Param)
	provideCall := newCallExpr(
		newSelectorExpr(g.ProvideFunction),
		[]ast.Expr{newBasicLit(inject.ProviderId)},
	)
	fetch := &ast.AssignStmt{
		Lhs: newExprs(newIdent(tmp), newIdent("err")),
		Tok: token.DEFINE,
		Rhs: newExprs(provideCall),
	}
	cast := &ast.AssignStmt{
		Lhs: newExprs(newIdent(inject.Param)),
		Tok: token.DEFINE,
		Rhs: newExprs(&ast.TypeAssertExpr{X: newIdent(tmp), Type: inject.Typ}),
	}
	return []ast.Stmt{fetch, newErrCheckStmt(), cast}
}
// defineProviderFunc builds the generated function named fn.providerFuncName().
// That function first fetches every injector's value (see defineInjectStmts).
// It then calls fn.Name with those values, stores the result in
// fn.providerObjName(), and passes it to RegisterFunction under fn.ProviderId.
func (g *Generator) defineProviderFunc(fn *DiFunc) *ast.FuncDecl {
	body := make([]ast.Stmt, 0)
	callArgs := make([]ast.Expr, 0)
	for _, inject := range fn.Injectors {
		callArgs = append(callArgs, newIdent(inject.GetArgName()))
		body = append(body, g.defineInjectStmts(inject)...)
	}

	objVar := fn.providerObjName()
	construct := &ast.AssignStmt{
		Lhs: newExprs(newIdent(objVar)),
		Tok: token.DEFINE,
		Rhs: newExprs(newCallExpr(newIdent(fn.Name), callArgs)),
	}
	registerArgs := newExprs(newBasicLit(fn.ProviderId), newIdent(objVar))
	register := &ast.ExprStmt{
		X: newCallExpr(newSelectorExpr(g.RegisterFunction), registerArgs),
	}
	body = append(body, construct, register)

	funcName := fn.providerFuncName()
	doc := newCommentGroup([]string{
		fmt.Sprintf("\n// %s registers the singleton object with ID %s into the DI object manager", funcName, fn.ProviderId),
		fmt.Sprintf("// Now you can retrieve the singleton object by using `obj, err := digo.Provide(\"%s\")`.", fn.ProviderId),
		"// The obj obtained from the above code is of type `any`.",
		"// You will need to forcefully cast the obj to its corresponding actual object type.",
	})
	return &ast.FuncDecl{
		Doc:  doc,
		Name: newIdent(funcName),
		Type: &ast.FuncType{},
		Body: &ast.BlockStmt{List: body},
	}
}
// defineProviderFuncs emits one registration function for each function in
// g.Package.Funcs that has a provider ID. It also queues a call to each
// emitted function in CalledInitFuncs, so the generated init() runs them in
// the iteration order of g.Package.Funcs.
func (g *Generator) defineProviderFuncs() {
	for _, fn := range g.Package.Funcs {
		if fn.ProviderId == "" {
			continue
		}
		g.Decls = append(g.Decls, g.defineProviderFunc(fn))

		call := newCallExpr(newIdent(fn.providerFuncName()), newExprs())
		g.CalledInitFuncs = append(g.CalledInitFuncs, &ast.ExprStmt{X: call})
	}
}
// defineGroupFuncs emits one group-member function for each function in
// g.Package.Funcs that has a group ID. It also queues a call to each emitted
// function in CalledInitFuncs, after any calls queued earlier (Do runs
// defineProviderFuncs first).
func (g *Generator) defineGroupFuncs() {
	for _, fn := range g.Package.Funcs {
		if fn.GroupId == "" {
			continue
		}
		g.Decls = append(g.Decls, g.defineGroupFunc(fn))

		call := newCallExpr(newIdent(fn.groupFuncName()), newExprs())
		g.CalledInitFuncs = append(g.CalledInitFuncs, &ast.ExprStmt{X: call})
	}
}
// addImport records that the generated file must import pkg under alias (an
// empty alias means no alias). Entries are keyed by pkg+"_"+alias, so asking
// for the same pair again is a no-op.
func (g *Generator) addImport(pkg string, alias string) {
	key := pkg + "_" + alias
	if _, seen := g.ImportSpecs[key]; seen {
		return
	}
	g.ImportSpecs[key] = newImportSpec(pkg, alias)
}
// defineGroupFunc builds the generated function named fn.groupFuncName(),
// which adds fn's object to the group fn.GroupId.
//
// If fn is also a provider (non-empty ProviderId), its object is fetched by
// provider ID through ProvideFunction. Do queues the provider registrations
// in init() before the group calls, because defineProviderFuncs runs first.
// Otherwise every injector's value is fetched (see defineInjectStmts) and
// fn.Name is called with them to build the member. In both cases the member
// is then passed to GroupFunction together with fn.GroupId.
func (g *Generator) defineGroupFunc(fn *DiFunc) *ast.FuncDecl {
	stmts := make([]ast.Stmt, 0)
	args := make([]ast.Expr, 0)
	if len(fn.ProviderId) > 0 {
		// member, err := <ProvideFunction>("<ProviderId>")
		// ProvideFunction is the lookup that returns (object, error);
		// GroupFunction takes (groupID, member) and cannot be used to fetch.
		stmts = append(stmts,
			&ast.AssignStmt{
				Lhs: newExprs(newIdent("member"), newIdent("err")),
				Tok: token.DEFINE,
				Rhs: newExprs(
					newCallExpr(
						newSelectorExpr(g.ProvideFunction),
						[]ast.Expr{newBasicLit(fn.ProviderId)},
					),
				),
			},
			newErrCheckStmt(),
		)
	} else {
		// Build the member directly: fetch each injected value, then call fn.
		for _, inject := range fn.Injectors {
			// NOTE(review): defineProviderFunc passes inject.GetArgName() as the
			// argument, while this uses inject.Param (the variable declared by
			// defineInjectStmts); confirm the two agree for every injector.
			args = append(args, newIdent(inject.Param))
			stmts = append(stmts, g.defineInjectStmts(inject)...)
		}
		stmts = append(stmts, &ast.AssignStmt{
			Lhs: newExprs(newIdent("member")),
			Tok: token.DEFINE,
			Rhs: newExprs(newCallExpr(newIdent(fn.Name), args)),
		})
	}

	// <GroupFunction>("<GroupId>", member)
	stmts = append(stmts, &ast.ExprStmt{
		X: newCallExpr(newSelectorExpr(g.GroupFunction), newExprs(
			newBasicLit(fn.GroupId),
			newIdent("member")),
		),
	})

	comments := []string{
		fmt.Sprintf("\n// Add a member object to group: %s", fn.GroupId),
		fmt.Sprintf("// Now you can retrieve the group's member objects by using `objs, err := digo.Members(\"%s\")`.", fn.GroupId),
		"// The objs obtained from the above code are of type `[]any`.",
		"// You will need to forcefully cast the objs to their corresponding actual object types.",
	}
	return &ast.FuncDecl{
		Doc:  newCommentGroup(comments),
		Name: newIdent(fn.groupFuncName()),
		Type: &ast.FuncType{},
		Body: &ast.BlockStmt{List: stmts},
	}
}
// defineInitFunc appends the generated init() function to g.Decls. Its body is
// exactly the call list accumulated in CalledInitFuncs: the provider
// registration functions and the group-member functions.
func (g *Generator) defineInitFunc() {
	doc := newCommentGroup([]string{
		"\n// init registers all providers in the current package into the DI object manager.",
	})
	initFn := &ast.FuncDecl{
		Doc:  doc,
		Name: newIdent("init"),
		Type: &ast.FuncType{},
		Body: &ast.BlockStmt{List: g.CalledInitFuncs},
	}
	g.Decls = append(g.Decls, initFn)
}
// genAllAstDecls prepends a single import declaration, holding every spec in
// g.ImportSpecs, to g.Decls so imports precede the generated functions.
//
// Specs are emitted in sorted key order. Map iteration order is random. This
// GenDecl has no Lparen position, and ast.SortImports (and format.Node's
// import sorting) skip such declarations. Without sorting, the import order
// of the generated file would therefore change from run to run.
func (g *Generator) genAllAstDecls() {
	keys := make([]string, 0, len(g.ImportSpecs))
	for key := range g.ImportSpecs {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	importSpecs := make([]ast.Spec, 0, len(keys))
	for _, key := range keys {
		importSpecs = append(importSpecs, g.ImportSpecs[key])
	}
	g.Decls = append([]ast.Decl{&ast.GenDecl{
		Tok:   token.IMPORT,
		Specs: importSpecs,
	}}, g.Decls...)
}
// writeHeaderComment writes the "generated by digogen" banner to file. It
// returns the number of bytes written, which output uses as the start offset
// of the generated AST.
//
// The banner is written verbatim with WriteString rather than being passed to
// fmt.Fprintf as a format string. go vet flags non-constant format strings,
// and any '%' in the text would be misinterpreted.
func (g *Generator) writeHeaderComment(file *os.File) int {
	header := "\n//\n// This file is generated by digogen. Run 'digogen' to regenerate.\n//\n" +
		"// You can install this tool by running `go install github.com/werbenhu/digo/digogen`.\n" +
		"// For more details, please refer to https://github.com/werbenhu/digo. \n//\n"
	// The signature has no error result, so a write error is not returned
	// here. n reports how many bytes actually reached the file; it equals
	// len(header) on success.
	n, _ := file.WriteString(header)
	return n
}
// output writes the generated declarations to Package.Folder/GeneratedFileName.
// It creates the folder if needed and replaces any previous file.
//
// It returns the first error from creating the folder, opening the file,
// formatting/writing the AST, or closing the file. Errors are wrapped with
// the path involved.
func (g *Generator) output() (err error) {
	if err = os.MkdirAll(g.Package.Folder, 0777); err != nil {
		return fmt.Errorf("creating folder %q: %w", g.Package.Folder, err)
	}
	path := filepath.Join(g.Package.Folder, g.GeneratedFileName)
	// Best effort: the file may not exist yet. If removal fails, O_TRUNC
	// still discards the stale content.
	_ = os.Remove(path)
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		return fmt.Errorf("opening %q: %w", path, err)
	}
	// Close errors matter for writes (data may only be flushed at close).
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("closing %q: %w", path, cerr)
		}
	}()

	g.genAllAstDecls()
	startPos := g.writeHeaderComment(file)
	dest := &ast.File{
		FileStart: token.Pos(startPos),
		Name: &ast.Ident{
			Name: g.Package.Name,
		},
		Decls: g.Decls,
	}
	ast.SortImports(g.Fset, dest)
	if err = format.Node(file, g.Fset, dest); err != nil {
		return fmt.Errorf("writing %q: %w", path, err)
	}
	return nil
}
// Do converts the providers and group members found in g.Package into Go AST
// declarations and writes them to GeneratedFileName in the package folder.
// The DI runtime package (ManagerPackage) is always imported.
//
// Do has no error result. If writing the file fails, the error is reported on
// stderr instead of being silently dropped.
func (g *Generator) Do() {
	g.addImport(g.ManagerPackage, "")
	g.defineProviderFuncs()
	g.defineGroupFuncs()
	g.defineInitFunc()
	if err := g.output(); err != nil {
		fmt.Fprintf(os.Stderr, "digo: generating code for package %s: %v\n", g.Package.Name, err)
	}
}