Skip to content

Commit 7776922

Browse files
committed
convert to asyncpg-based
1 parent 1ac5932 commit 7776922

File tree

3 files changed

+67
-209
lines changed

3 files changed

+67
-209
lines changed

internal/config.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package python
22

33
type Config struct {
44
EmitExactTableNames bool `json:"emit_exact_table_names"`
5-
EmitSyncQuerier bool `json:"emit_sync_querier"`
6-
EmitAsyncQuerier bool `json:"emit_async_querier"`
75
Package string `json:"package"`
86
Out string `json:"out"`
97
EmitPydanticModels bool `json:"emit_pydantic_models"`

internal/gen.go

Lines changed: 65 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -354,18 +354,6 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum
354354
return &gs
355355
}
356356

357-
var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$(\d+)\b`)
358-
359-
// Sqlalchemy uses ":name" for placeholders, so "$N" is converted to ":pN"
360-
// This also means ":" has special meaning to sqlalchemy, so it must be escaped.
361-
func sqlalchemySQL(s, engine string) string {
362-
s = strings.ReplaceAll(s, ":", `\\:`)
363-
if engine == "postgresql" {
364-
return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1")
365-
}
366-
return s
367-
}
368-
369357
func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([]Query, error) {
370358
qs := make([]Query, 0, len(req.Queries))
371359
for _, query := range req.Queries {
@@ -387,7 +375,7 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([
387375
MethodName: methodName,
388376
FieldName: sdk.LowerTitle(query.Name) + "Stmt",
389377
ConstantName: strings.ToUpper(methodName),
390-
SQL: sqlalchemySQL(query.Text, req.Settings.Engine),
378+
SQL: query.Text,
391379
SourceName: query.Filename,
392380
}
393381

@@ -625,18 +613,7 @@ func typeRefNode(base string, parts ...string) *pyast.Node {
625613
}
626614

627615
func connMethodNode(method, name string, arg *pyast.Node) *pyast.Node {
628-
args := []*pyast.Node{
629-
{
630-
Node: &pyast.Node_Call{
631-
Call: &pyast.Call{
632-
Func: typeRefNode("sqlalchemy", "text"),
633-
Args: []*pyast.Node{
634-
poet.Name(name),
635-
},
636-
},
637-
},
638-
},
639-
}
616+
args := []*pyast.Node{poet.Name(name)}
640617
if arg != nil {
641618
args = append(args, arg)
642619
}
@@ -792,7 +769,7 @@ func asyncQuerierClassDef() *pyast.ClassDef {
792769
},
793770
{
794771
Arg: "conn",
795-
Annotation: typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection"),
772+
Annotation: typeRefNode("asyncpg", "pool", "PoolConnectionProxy"),
796773
},
797774
},
798775
},
@@ -872,190 +849,81 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
872849
}
873850
}
874851

875-
if ctx.C.EmitSyncQuerier {
876-
cls := querierClassDef()
877-
for _, q := range ctx.Queries {
878-
if !ctx.OutputQuery(q.SourceName) {
879-
continue
880-
}
881-
f := &pyast.FunctionDef{
882-
Name: q.MethodName,
883-
Args: &pyast.Arguments{
884-
Args: []*pyast.Arg{
885-
{
886-
Arg: "self",
887-
},
852+
cls := asyncQuerierClassDef()
853+
for _, q := range ctx.Queries {
854+
if !ctx.OutputQuery(q.SourceName) {
855+
continue
856+
}
857+
f := &pyast.AsyncFunctionDef{
858+
Name: q.MethodName,
859+
Args: &pyast.Arguments{
860+
Args: []*pyast.Arg{
861+
{
862+
Arg: "self",
888863
},
889864
},
890-
}
891-
892-
q.AddArgs(f.Args)
893-
exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode())
865+
},
866+
}
894867

895-
switch q.Cmd {
896-
case ":one":
897-
f.Body = append(f.Body,
898-
assignNode("row", poet.Node(
899-
&pyast.Call{
900-
Func: poet.Attribute(exec, "first"),
901-
},
902-
)),
903-
poet.Node(
904-
&pyast.If{
905-
Test: poet.Node(
906-
&pyast.Compare{
907-
Left: poet.Name("row"),
908-
Ops: []*pyast.Node{
909-
poet.Is(),
910-
},
911-
Comparators: []*pyast.Node{
912-
poet.Constant(nil),
913-
},
868+
q.AddArgs(f.Args)
869+
870+
switch q.Cmd {
871+
case ":one":
872+
fetchrow := connMethodNode("fetchrow", q.ConstantName, q.ArgDictNode())
873+
f.Body = append(f.Body,
874+
assignNode("row", poet.Await(fetchrow)),
875+
poet.Node(
876+
&pyast.If{
877+
Test: poet.Node(
878+
&pyast.Compare{
879+
Left: poet.Name("row"),
880+
Ops: []*pyast.Node{
881+
poet.Is(),
914882
},
915-
),
916-
Body: []*pyast.Node{
917-
poet.Return(
883+
Comparators: []*pyast.Node{
918884
poet.Constant(nil),
919-
),
885+
},
920886
},
887+
),
888+
Body: []*pyast.Node{
889+
poet.Return(
890+
poet.Constant(nil),
891+
),
921892
},
922-
),
923-
poet.Return(q.Ret.RowNode("row")),
924-
)
925-
f.Returns = subscriptNode("Optional", q.Ret.Annotation())
926-
case ":many":
927-
f.Body = append(f.Body,
928-
assignNode("result", exec),
929-
poet.Node(
930-
&pyast.For{
931-
Target: poet.Name("row"),
932-
Iter: poet.Name("result"),
933-
Body: []*pyast.Node{
934-
poet.Expr(
935-
poet.Yield(
936-
q.Ret.RowNode("row"),
937-
),
893+
},
894+
),
895+
poet.Return(q.Ret.RowNode("row")),
896+
)
897+
f.Returns = subscriptNode("Optional", q.Ret.Annotation())
898+
case ":many":
899+
cursor := connMethodNode("cursor", q.ConstantName, q.ArgDictNode())
900+
f.Body = append(f.Body,
901+
poet.Node(
902+
&pyast.AsyncFor{
903+
Target: poet.Name("row"),
904+
Iter: cursor,
905+
Body: []*pyast.Node{
906+
poet.Expr(
907+
poet.Yield(
908+
q.Ret.RowNode("row"),
938909
),
939-
},
940-
},
941-
),
942-
)
943-
f.Returns = subscriptNode("Iterator", q.Ret.Annotation())
944-
case ":exec":
945-
f.Body = append(f.Body, exec)
946-
f.Returns = poet.Constant(nil)
947-
case ":execrows":
948-
f.Body = append(f.Body,
949-
assignNode("result", exec),
950-
poet.Return(poet.Attribute(poet.Name("result"), "rowcount")),
951-
)
952-
f.Returns = poet.Name("int")
953-
case ":execresult":
954-
f.Body = append(f.Body,
955-
poet.Return(exec),
956-
)
957-
f.Returns = typeRefNode("sqlalchemy", "engine", "Result")
958-
default:
959-
panic("unknown cmd " + q.Cmd)
960-
}
961-
962-
cls.Body = append(cls.Body, poet.Node(f))
963-
}
964-
mod.Body = append(mod.Body, poet.Node(cls))
965-
}
966-
967-
if ctx.C.EmitAsyncQuerier {
968-
cls := asyncQuerierClassDef()
969-
for _, q := range ctx.Queries {
970-
if !ctx.OutputQuery(q.SourceName) {
971-
continue
972-
}
973-
f := &pyast.AsyncFunctionDef{
974-
Name: q.MethodName,
975-
Args: &pyast.Arguments{
976-
Args: []*pyast.Arg{
977-
{
978-
Arg: "self",
910+
),
979911
},
980912
},
981-
},
982-
}
983-
984-
q.AddArgs(f.Args)
913+
),
914+
)
915+
f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation())
916+
case ":exec":
985917
exec := connMethodNode("execute", q.ConstantName, q.ArgDictNode())
986-
987-
switch q.Cmd {
988-
case ":one":
989-
f.Body = append(f.Body,
990-
assignNode("row", poet.Node(
991-
&pyast.Call{
992-
Func: poet.Attribute(poet.Await(exec), "first"),
993-
},
994-
)),
995-
poet.Node(
996-
&pyast.If{
997-
Test: poet.Node(
998-
&pyast.Compare{
999-
Left: poet.Name("row"),
1000-
Ops: []*pyast.Node{
1001-
poet.Is(),
1002-
},
1003-
Comparators: []*pyast.Node{
1004-
poet.Constant(nil),
1005-
},
1006-
},
1007-
),
1008-
Body: []*pyast.Node{
1009-
poet.Return(
1010-
poet.Constant(nil),
1011-
),
1012-
},
1013-
},
1014-
),
1015-
poet.Return(q.Ret.RowNode("row")),
1016-
)
1017-
f.Returns = subscriptNode("Optional", q.Ret.Annotation())
1018-
case ":many":
1019-
stream := connMethodNode("stream", q.ConstantName, q.ArgDictNode())
1020-
f.Body = append(f.Body,
1021-
assignNode("result", poet.Await(stream)),
1022-
poet.Node(
1023-
&pyast.AsyncFor{
1024-
Target: poet.Name("row"),
1025-
Iter: poet.Name("result"),
1026-
Body: []*pyast.Node{
1027-
poet.Expr(
1028-
poet.Yield(
1029-
q.Ret.RowNode("row"),
1030-
),
1031-
),
1032-
},
1033-
},
1034-
),
1035-
)
1036-
f.Returns = subscriptNode("AsyncIterator", q.Ret.Annotation())
1037-
case ":exec":
1038-
f.Body = append(f.Body, poet.Await(exec))
1039-
f.Returns = poet.Constant(nil)
1040-
case ":execrows":
1041-
f.Body = append(f.Body,
1042-
assignNode("result", poet.Await(exec)),
1043-
poet.Return(poet.Attribute(poet.Name("result"), "rowcount")),
1044-
)
1045-
f.Returns = poet.Name("int")
1046-
case ":execresult":
1047-
f.Body = append(f.Body,
1048-
poet.Return(poet.Await(exec)),
1049-
)
1050-
f.Returns = typeRefNode("sqlalchemy", "engine", "Result")
1051-
default:
1052-
panic("unknown cmd " + q.Cmd)
1053-
}
1054-
1055-
cls.Body = append(cls.Body, poet.Node(f))
918+
f.Body = append(f.Body, poet.Await(exec))
919+
f.Returns = poet.Constant(nil)
920+
default:
921+
panic("unknown cmd " + q.Cmd)
1056922
}
1057-
mod.Body = append(mod.Body, poet.Node(cls))
923+
924+
cls.Body = append(cls.Body, poet.Node(f))
1058925
}
926+
mod.Body = append(mod.Body, poet.Node(cls))
1059927

1060928
return poet.Node(mod)
1061929
}

internal/imports.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
131131
std := stdImports(queryUses)
132132

133133
pkg := make(map[string]importSpec)
134-
pkg["sqlalchemy"] = importSpec{Module: "sqlalchemy"}
135-
if i.C.EmitAsyncQuerier {
136-
pkg["sqlalchemy.ext.asyncio"] = importSpec{Module: "sqlalchemy.ext.asyncio"}
137-
}
134+
pkg["asyncpg"] = importSpec{Module: "asyncpg"}
138135

139136
queryValueModelImports := func(qv QueryValue) {
140137
if qv.IsStruct() && qv.EmitStruct() {
@@ -154,12 +151,7 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
154151
std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"}
155152
}
156153
if q.Cmd == ":many" {
157-
if i.C.EmitSyncQuerier {
158-
std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"}
159-
}
160-
if i.C.EmitAsyncQuerier {
161-
std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"}
162-
}
154+
std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"}
163155
}
164156
queryValueModelImports(q.Ret)
165157
for _, qv := range q.Args {

0 commit comments

Comments
 (0)